isa-model 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +359 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +101 -0
- isa_model/inference/providers/replicate_provider.py +107 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/openai_tts_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +97 -0
- isa_model/inference/services/embedding/openai_embed_service.py +0 -0
- isa_model/inference/services/llm/__init__.py +12 -0
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +99 -0
- isa_model/inference/services/llm/openai_llm_service.py +138 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/openai_vision_service.py +80 -0
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.0.1.dist-info/METADATA +327 -0
- isa_model-0.0.1.dist-info/RECORD +86 -0
- isa_model-0.0.1.dist-info/WHEEL +5 -0
- isa_model-0.0.1.dist-info/licenses/LICENSE +21 -0
- isa_model-0.0.1.dist-info/top_level.txt +1 -0
isa_model/inference/services/others/table_transformer_service.py
@@ -0,0 +1,61 @@
+from typing import Dict, Any, List, Union, Optional
+from ...base_service import BaseService
+from ...base_provider import BaseProvider
+from transformers import TableTransformerForObjectDetection, DetrImageProcessor
+import torch
+from PIL import Image
+import numpy as np
+
+class TableTransformerService(BaseService):
+    """Table detection service using Microsoft's Table Transformer"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str = "microsoft/table-transformer-detection"):
+        super().__init__(provider, model_name)
+        self.processor = DetrImageProcessor.from_pretrained(model_name)
+        self.model = TableTransformerForObjectDetection.from_pretrained(model_name)
+        if torch.cuda.is_available():
+            self.model = self.model.cuda()
+
+    async def detect_tables(self, image_path: str) -> Dict[str, Any]:
+        """Detect tables in image"""
+        try:
+            # Load and process image
+            image = Image.open(image_path)
+            inputs = self.processor(images=image, return_tensors="pt")
+
+            if torch.cuda.is_available():
+                inputs = {k: v.cuda() for k, v in inputs.items()}
+
+            # Run inference
+            outputs = self.model(**inputs)
+
+            # Convert outputs to image size
+            target_sizes = torch.tensor([image.size[::-1]])
+            results = self.processor.post_process_object_detection(
+                outputs, target_sizes=target_sizes, threshold=0.7
+            )[0]
+
+            tables = []
+            for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+                if label == 1:  # Table class
+                    tables.append({
+                        "confidence": score.item(),
+                        "bbox": box.tolist(),
+                        "type": "table"
+                    })
+
+            return {
+                "tables": tables,
+                "image_size": image.size
+            }
+
+        except Exception as e:
+            raise RuntimeError(f"Table detection failed: {e}")
+
+    async def close(self):
+        """Cleanup resources"""
+        if hasattr(self, 'model'):
+            del self.model
+        if hasattr(self, 'processor'):
+            del self.processor
+        torch.cuda.empty_cache()
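For reviewers who want to sanity-check this hunk, here is a minimal standalone sketch of the same detection flow using the transformers API directly. The model name comes from the diff; the input file name is a hypothetical placeholder.

```python
# Standalone sketch of the detect_tables() flow above (input path is hypothetical).
import torch
from PIL import Image
from transformers import DetrImageProcessor, TableTransformerForObjectDetection

processor = DetrImageProcessor.from_pretrained("microsoft/table-transformer-detection")
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")
model.eval()

image = Image.open("scanned_page.png").convert("RGB")  # hypothetical input file
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Map raw detections back to pixel coordinates, exactly as the service does
target_sizes = torch.tensor([image.size[::-1]])
detections = processor.post_process_object_detection(
    outputs, target_sizes=target_sizes, threshold=0.7
)[0]
for score, label, box in zip(detections["scores"], detections["labels"], detections["boxes"]):
    print(label.item(), round(score.item(), 3), [round(v, 1) for v in box.tolist()])
```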
isa_model/inference/services/vision/__init__.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Vision services package
+Contains all vision-related service modules
+"""
+
+# Export ReplicateVisionService
+from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateVisionService
+
+__all__ = ["ReplicateVisionService"]
isa_model/inference/services/vision/helpers/image_utils.py
@@ -0,0 +1,58 @@
+from io import BytesIO
+from PIL import Image
+from typing import Union
+import base64
+from app.config.config_manager import config_manager
+
+logger = config_manager.get_logger(__name__)
+
+def compress_image(image_data: Union[bytes, BytesIO], max_size: int = 1024) -> bytes:
+    """Compress an image to reduce its size
+
+    Args:
+        image_data: Image data, either bytes or BytesIO
+        max_size: Maximum dimension in pixels
+
+    Returns:
+        bytes: Compressed image data
+    """
+    try:
+        # If the input is bytes, wrap it in BytesIO
+        if isinstance(image_data, bytes):
+            image_data = BytesIO(image_data)
+
+        img = Image.open(image_data)
+
+        # Convert to RGB mode (if needed)
+        if img.mode in ('RGBA', 'P'):
+            img = img.convert('RGB')
+
+        # Compute the new size, preserving the aspect ratio
+        ratio = max_size / max(img.size)
+        if ratio < 1:
+            new_size = tuple(int(dim * ratio) for dim in img.size)
+            img = img.resize(new_size, Image.Resampling.LANCZOS)
+
+        # Save the compressed image
+        output = BytesIO()
+        img.save(output, format='JPEG', quality=85, optimize=True)
+        return output.getvalue()
+
+    except Exception as e:
+        logger.error(f"Error compressing image: {e}")
+        raise
+
+def encode_image_to_base64(image_data: bytes) -> str:
+    """Encode image data as a base64 string
+
+    Args:
+        image_data: Raw image bytes
+
+    Returns:
+        str: base64-encoded string
+    """
+    try:
+        return base64.b64encode(image_data).decode('utf-8')
+    except Exception as e:
+        logger.error(f"Error encoding image to base64: {e}")
+        raise
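A quick self-contained check of the two helpers, assuming the module's `app.config.config_manager` import resolves in your environment; the oversized test image is generated in memory so nothing else is needed.

```python
# Builds a large test image in memory, then runs the helpers above on it.
from io import BytesIO
from PIL import Image

buf = BytesIO()
Image.new("RGB", (2048, 1536), "white").save(buf, format="PNG")
raw = buf.getvalue()

compressed = compress_image(raw, max_size=1024)  # longest side capped at 1024 px, JPEG q=85
b64 = encode_image_to_base64(compressed)
print(len(raw), len(compressed), b64[:32])
```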
isa_model/inference/services/vision/helpers/text_splitter.py
@@ -0,0 +1,46 @@
+from typing import Dict, List, Optional
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from uuid import uuid4
+
+class TextChunkHelper:
+    """Text splitting and chunking helper"""
+
+    def __init__(self,
+                 chunk_size: int = 512,
+                 chunk_overlap: int = 50,
+                 min_chunk_size: int = 50):
+        self.text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+            length_function=len,
+            separators=["\n\n", "\n", ". ", ", ", " "]
+        )
+        self.min_chunk_size = min_chunk_size
+
+    def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
+        """Create text chunks with metadata"""
+        if not text or not isinstance(text, str):
+            raise ValueError("Text must be a non-empty string")
+
+        chunks = self.text_splitter.split_text(text)
+        valid_chunks = [
+            chunk for chunk in chunks
+            if len(chunk) >= self.min_chunk_size
+        ]
+
+        results = []
+        for i, chunk in enumerate(valid_chunks):
+            chunk_data = {
+                "chunk_id": f"chunk_{uuid4().hex[:8]}",
+                "content": chunk,
+                "token_count": len(chunk),
+                "metadata": {
+                    **(metadata or {}),
+                    "position": i,
+                    "start_idx": text.find(chunk),
+                    "end_idx": text.find(chunk) + len(chunk)
+                }
+            }
+            results.append(chunk_data)
+
+        return results
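A hypothetical usage sketch (requires `langchain`). Two caveats visible in the hunk itself: `token_count` is really a character count (`len(chunk)`), and `text.find(chunk)` returns the first occurrence, so `start_idx`/`end_idx` can be wrong when a chunk repeats in the text.

```python
# Hypothetical driver for TextChunkHelper; sizes chosen for illustration only.
helper = TextChunkHelper(chunk_size=120, chunk_overlap=20, min_chunk_size=30)
text = (
    "Table Transformer detects tables in document images. "
    "It is exposed here behind a thin async service wrapper. "
    "Chunks below the minimum size are dropped before metadata is attached."
)
for chunk in helper.create_chunks(text, metadata={"source": "demo.txt"}):
    print(chunk["chunk_id"], chunk["token_count"], chunk["metadata"]["position"])
```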
isa_model/inference/services/vision/ollama_vision_service.py
@@ -0,0 +1,60 @@
+import os
+import json
+import base64
+import ollama
+from typing import Dict, Any, Union
+from tenacity import retry, stop_after_attempt, wait_exponential
+from isa_model.inference.services.base_service import BaseService
+from isa_model.inference.providers.base_provider import BaseProvider
+import logging
+
+logger = logging.getLogger(__name__)
+
+class OllamaVisionService(BaseService):
+    """Vision model service wrapper for Ollama using base64 encoded images"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str = 'gemma3:4b'):
+        super().__init__(provider, model_name)
+        self.max_tokens = self.config.get('max_tokens', 1000)
+        self.temperature = self.config.get('temperature', 0.7)
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        reraise=True
+    )
+    async def analyze_image(self, image_data: Union[bytes, str], query: str) -> str:
+        """Analyze an image and return the result
+
+        Args:
+            image_data: Image data, either bytes or an image path string
+            query: Query text
+
+        Returns:
+            str: Analysis result
+        """
+        try:
+            # If given a file path, read the file contents
+            if isinstance(image_data, str):
+                with open(image_data, 'rb') as f:
+                    image_data = f.read()
+
+            # Convert to base64
+            image_base64 = base64.b64encode(image_data).decode('utf-8')
+
+            # Call the ollama library directly
+            response = ollama.chat(
+                model=self.model_name,
+                messages=[{
+                    'role': 'user',
+                    'content': query,
+                    'images': [image_base64]
+                }]
+            )
+
+            return response['message']['content']
+
+        except Exception as e:
+            logger.error(f"Error in image analysis: {e}")
+            raise
+
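A hedged driver sketch for the service above; it assumes a running local Ollama daemon with the model pulled, and a `provider` object compatible with `BaseService` (not shown in this diff).

```python
# Hypothetical usage of OllamaVisionService; `provider` and the image file are assumptions.
import asyncio

async def main():
    service = OllamaVisionService(provider, model_name="gemma3:4b")  # `provider` is assumed
    answer = await service.analyze_image("photo.jpg", "What is in this image?")  # hypothetical file
    print(answer)

asyncio.run(main())
```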
isa_model/inference/services/vision/openai_vision_service.py
@@ -0,0 +1,80 @@
+from typing import Dict, Any, Union
+from openai import AsyncOpenAI
+from tenacity import retry, stop_after_attempt, wait_exponential
+from isa_model.inference.services.base_service import BaseService
+from isa_model.inference.providers.base_provider import BaseProvider
+from .helpers.image_utils import compress_image, encode_image_to_base64
+import logging
+
+logger = logging.getLogger(__name__)
+
+class OpenAIVisionService(BaseService):
+    """Vision model service wrapper for YYDS"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str):
+        super().__init__(provider, model_name)
+        # Initialize the AsyncOpenAI client
+        self._client = AsyncOpenAI(
+            api_key=self.config.get('api_key'),
+            base_url=self.config.get('base_url')
+        )
+        self.max_tokens = self.config.get('max_tokens', 1000)
+        self.temperature = self.config.get('temperature', 0.7)
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        """Access the underlying OpenAI client"""
+        return self._client
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        reraise=True
+    )
+    async def analyze_image(self, image_data: Union[bytes, str], query: str) -> str:
+        """Analyze an image and return the result
+
+        Args:
+            image_data: Image data, either bytes or an already-encoded base64 string
+            query: Query text
+
+        Returns:
+            str: Analysis result
+        """
+        try:
+            # Prepare the image data
+            if isinstance(image_data, bytes):
+                # Compress and encode the image
+                compressed_image = compress_image(image_data)
+                image_b64 = encode_image_to_base64(compressed_image)
+            else:
+                image_b64 = image_data
+
+            # Strip any base64 prefix that may be present
+            if 'base64,' in image_b64:
+                image_b64 = image_b64.split('base64,')[1]
+
+            # Create the request with the AsyncOpenAI client
+            response = await self._client.chat.completions.create(
+                model=self.model_name,
+                messages=[{
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": query},
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{image_b64}"
+                            }
+                        }
+                    ]
+                }],
+                max_tokens=self.max_tokens,
+                temperature=self.temperature
+            )
+
+            return response.choices[0].message.content
+
+        except Exception as e:
+            logger.error(f"Error in image analysis: {e}")
+            raise
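The equivalent hedged sketch for the OpenAI-backed service; the model name and input file are placeholders, and `provider` is again assumed to supply `api_key`/`base_url` through the service config.

```python
# Hypothetical usage of OpenAIVisionService; all names here are placeholders.
import asyncio

async def main():
    service = OpenAIVisionService(provider, model_name="gpt-4o-mini")  # placeholder model
    with open("receipt.jpg", "rb") as f:  # bytes input triggers compress + base64 encode
        print(await service.analyze_image(f.read(), "Summarize this receipt."))

asyncio.run(main())
```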
isa_model/inference/services/vision/replicate_image_gen_service.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Replicate Vision service
+Interacts with the Replicate API; supports image generation and image analysis
+"""
+
+import os
+import time
+import uuid
+import logging
+from typing import Dict, Any, List, Optional, Union, Tuple
+import asyncio
+import aiohttp
+import replicate  # import the replicate library
+from PIL import Image
+from io import BytesIO
+
+# Adjust the BaseService import path to match your project structure
+from isa_model.inference.services.base_service import BaseService
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.base import ModelType
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class ReplicateVisionService(BaseService):
+    """
+    Replicate Vision service for image generation and analysis.
+    Adjusted to use native async calls and optimized file handling.
+    """
+
+    def __init__(self, provider: BaseProvider, model_name: str):
+        """
+        Initialize the Replicate Vision service
+        """
+        super().__init__(provider, model_name)
+        # Get the API token from the provider or environment variables
+        self.api_token = self.provider.config.get("api_token", os.environ.get("REPLICATE_API_TOKEN"))
+        self.model_type = ModelType.VISION
+
+        # Optional default configuration
+        self.guidance_scale = self.provider.config.get("guidance_scale", 7.5)
+        self.num_inference_steps = self.provider.config.get("num_inference_steps", 30)
+
+        # Directory where generated images are stored
+        self.output_dir = "generated_images"
+        os.makedirs(self.output_dir, exist_ok=True)
+
+        # ★ Adjustment: set the API token for the replicate library
+        if self.api_token:
+            # The replicate library reads it from the environment; make sure it is set
+            os.environ["REPLICATE_API_TOKEN"] = self.api_token
+        else:
+            logger.warning("Replicate API token not found. The service may not work properly.")
+
+    async def _prepare_input_files(self, input_data: Dict[str, Any]) -> Tuple[Dict[str, Any], List[Any]]:
+        """
+        ★ New helper: prepare input data, converting local file paths into file objects.
+        This lets the service handle local files and URLs uniformly.
+        """
+        prepared_input = input_data.copy()
+        files_to_close = []
+        for key, value in prepared_input.items():
+            # If the value is a string that looks like an existing local file path
+            if isinstance(value, str) and not value.startswith(('http://', 'https://')) and os.path.exists(value):
+                logger.info(f"Detected local file path '{value}'; opening file.")
+                try:
+                    file_handle = open(value, "rb")
+                    prepared_input[key] = file_handle
+                    files_to_close.append(file_handle)
+                except Exception as e:
+                    logger.error(f"Failed to open file '{value}': {e}")
+                    raise
+        return prepared_input, files_to_close
+
+    async def generate_image(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Generate images with a Replicate model (optimized to native async)
+        """
+        prepared_input, files_to_close = await self._prepare_input_files(input_data)
+        try:
+            # Set default parameters
+            if "guidance_scale" not in prepared_input:
+                prepared_input["guidance_scale"] = self.guidance_scale
+            if "num_inference_steps" not in prepared_input:
+                prepared_input["num_inference_steps"] = self.num_inference_steps
+
+            logger.info(f"Generating image with model {self.model_name} (native async)")
+
+            # ★ Adjustment: use the native async replicate.async_run
+            output = await replicate.async_run(self.model_name, input=prepared_input)
+
+            # Convert the result to a standard format (this logic is unchanged)
+            if isinstance(output, list):
+                urls = output
+            else:
+                urls = [output]
+
+            result = {
+                "urls": urls,
+                "metadata": {
+                    "model": self.model_name,
+                    "input": input_data  # return the original input for reference
+                }
+            }
+            logger.info(f"Image generation complete: {result['urls']}")
+            return result
+        except Exception as e:
+            logger.error(f"Image generation failed: {e}")
+            raise
+        finally:
+            # ★ New: make sure all opened files are closed
+            for f in files_to_close:
+                f.close()
+
+    async def analyze_image(self, image_path: str, prompt: str) -> Dict[str, Any]:
+        """
+        Analyze an image (optimized to native async)
+        """
+        input_data = {"image": image_path, "prompt": prompt}
+        prepared_input, files_to_close = await self._prepare_input_files(input_data)
+        try:
+            logger.info(f"Analyzing image with model {self.model_name} (native async)")
+            # ★ Adjustment: use the native async replicate.async_run
+            output = await replicate.async_run(self.model_name, input=prepared_input)
+
+            result = {
+                "text": "".join(output) if isinstance(output, list) else output,
+                "metadata": {
+                    "model": self.model_name,
+                    "input": input_data
+                }
+            }
+            logger.info("Image analysis complete")
+            return result
+        except Exception as e:
+            logger.error(f"Image analysis failed: {e}")
+            raise
+        finally:
+            # ★ New: make sure all opened files are closed
+            for f in files_to_close:
+                f.close()
+
+    async def generate_and_save(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Generate images and save them locally (no changes needed here)"""
+        result = await self.generate_image(input_data)
+        saved_paths = []
+        for i, url in enumerate(result["urls"]):
+            timestamp = int(time.time())
+            file_name = f"{self.output_dir}/{timestamp}_{uuid.uuid4().hex[:8]}_{i+1}.png"
+            try:
+                # Convert FileOutput object to string if necessary
+                url_str = str(url) if hasattr(url, "__str__") else url
+                await self._download_image(url_str, file_name)
+                saved_paths.append(file_name)
+                logger.info(f"Image saved to: {file_name}")
+            except Exception as e:
+                logger.error(f"Failed to save image: {e}")
+        result["saved_paths"] = saved_paths
+        return result
+
+    async def _download_image(self, url: str, save_path: str) -> None:
+        """Asynchronously download and save an image (no changes needed here)"""
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.get(url) as response:
+                    response.raise_for_status()
+                    content = await response.read()
+                    with Image.open(BytesIO(content)) as img:
+                        img.save(save_path)
+        except Exception as e:
+            logger.error(f"Error downloading image: {url}, {e}")
+            raise
+
+    # `load` and `unload` are usually lightweight in the Replicate API scenario
+    async def load(self) -> None:
+        if not self.api_token:
+            raise ValueError("Missing Replicate API token; set the REPLICATE_API_TOKEN environment variable or provide it in the provider config")
+        logger.info(f"Replicate Vision service ready, using model: {self.model_name}")
+
+    async def unload(self) -> None:
+        logger.info(f"Unloading Replicate Vision service: {self.model_name}")
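A hedged end-to-end sketch for the Replicate service: the model slug and prompt are illustrative only, `provider` is assumed, and `REPLICATE_API_TOKEN` must be set for `replicate.async_run` to authenticate.

```python
# Hypothetical usage of ReplicateVisionService; slug, prompt, and provider are assumptions.
import asyncio

async def main():
    service = ReplicateVisionService(provider, "black-forest-labs/flux-schnell")  # illustrative slug
    await service.load()  # raises early if the API token is missing
    result = await service.generate_and_save({"prompt": "a watercolor fox, autumn palette"})
    print(result["saved_paths"])

asyncio.run(main())
```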
isa_model/inference/utils/conversion/bge_rerank_convert.py
@@ -0,0 +1,73 @@
+import os
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from pathlib import Path
+
+def convert_bge_to_onnx(save_dir: str):
+    """Convert BGE reranker to ONNX format"""
+    try:
+        # Create save directory if it doesn't exist
+        save_dir = Path(save_dir).resolve()  # Get absolute path
+        save_dir.mkdir(parents=True, exist_ok=True)
+
+        model_name = "BAAI/bge-reranker-v2-m3"
+        save_path = str(save_dir / "model.onnx")  # Convert to string for absolute path
+
+        print(f"Loading model {model_name}...")
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForSequenceClassification.from_pretrained(model_name)
+        model.eval()
+
+        # Save tokenizer for later use
+        print("Saving tokenizer...")
+        tokenizer.save_pretrained(save_dir)
+
+        # Create dummy input
+        print("Creating dummy input...")
+        dummy_input = tokenizer(
+            [["what is panda?", "The giant panda is a bear species."]],
+            padding=True,
+            truncation=True,
+            return_tensors='pt',
+            max_length=512
+        )
+
+        # Export to ONNX with external data storage
+        print(f"Exporting to ONNX: {save_path}")
+        torch.onnx.export(
+            model,
+            (dummy_input['input_ids'], dummy_input['attention_mask']),
+            save_path,  # Using string absolute path
+            input_names=['input_ids', 'attention_mask'],
+            output_names=['logits'],
+            dynamic_axes={
+                'input_ids': {0: 'batch', 1: 'sequence'},
+                'attention_mask': {0: 'batch', 1: 'sequence'},
+                'logits': {0: 'batch'}
+            },
+            opset_version=16,
+            export_params=True,  # Export the trained parameter weights
+            do_constant_folding=True,  # Optimize constant-folding
+            verbose=True,
+            use_external_data_format=True  # Enable external data storage
+        )
+        print("Conversion completed successfully!")
+        return True
+
+    except Exception as e:
+        print(f"Error during conversion: {e}")
+        return False
+
+if __name__ == "__main__":
+    # Get the absolute path to the model directory
+    current_dir = Path(__file__).parent.parent
+    model_dir = current_dir / "model_converted" / "bge-reranker-v2-m3"
+
+    success = convert_bge_to_onnx(str(model_dir))
+    if success:
+        print(f"Model saved to: {model_dir}")
+        print("Files created:")
+        for file in model_dir.glob('*'):
+            print(f"- {file.name}")
+    else:
+        print("Conversion failed!")
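A hedged sanity-check for the exported artifact using `onnxruntime` (not part of this package); the paths mirror the script's defaults, and the dummy query/passage pair is the same one used for tracing.

```python
# Load the exported reranker and score one query/passage pair.
import onnxruntime as ort
from transformers import AutoTokenizer

model_dir = "model_converted/bge-reranker-v2-m3"  # matches the script's default output
session = ort.InferenceSession(f"{model_dir}/model.onnx")
tokenizer = AutoTokenizer.from_pretrained(model_dir)

enc = tokenizer(
    [["what is panda?", "The giant panda is a bear species."]],
    padding=True, truncation=True, max_length=512, return_tensors="np",
)
logits = session.run(
    ["logits"],
    {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]},
)[0]
print(logits)  # higher logit = more relevant query/passage pair
```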
File without changes
File without changes