isa-model 0.0.1 (isa_model-0.0.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
  12. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
  13. isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
  14. isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
  15. isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
  16. isa_model/inference/__init__.py +11 -0
  17. isa_model/inference/adapter/unified_api.py +248 -0
  18. isa_model/inference/ai_factory.py +359 -0
  19. isa_model/inference/base.py +46 -0
  20. isa_model/inference/providers/__init__.py +19 -0
  21. isa_model/inference/providers/base_provider.py +30 -0
  22. isa_model/inference/providers/model_cache_manager.py +341 -0
  23. isa_model/inference/providers/ollama_provider.py +73 -0
  24. isa_model/inference/providers/openai_provider.py +101 -0
  25. isa_model/inference/providers/replicate_provider.py +107 -0
  26. isa_model/inference/providers/triton_provider.py +439 -0
  27. isa_model/inference/services/__init__.py +14 -0
  28. isa_model/inference/services/audio/base_stt_service.py +91 -0
  29. isa_model/inference/services/audio/base_tts_service.py +136 -0
  30. isa_model/inference/services/audio/openai_tts_service.py +71 -0
  31. isa_model/inference/services/base_service.py +106 -0
  32. isa_model/inference/services/embedding/ollama_embed_service.py +97 -0
  33. isa_model/inference/services/embedding/openai_embed_service.py +0 -0
  34. isa_model/inference/services/llm/__init__.py +12 -0
  35. isa_model/inference/services/llm/base_llm_service.py +134 -0
  36. isa_model/inference/services/llm/ollama_llm_service.py +99 -0
  37. isa_model/inference/services/llm/openai_llm_service.py +138 -0
  38. isa_model/inference/services/others/table_transformer_service.py +61 -0
  39. isa_model/inference/services/vision/__init__.py +12 -0
  40. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  41. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  42. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  43. isa_model/inference/services/vision/openai_vision_service.py +80 -0
  44. isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
  45. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  46. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  47. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  48. isa_model/scripts/inference_tracker.py +283 -0
  49. isa_model/scripts/mlflow_manager.py +379 -0
  50. isa_model/scripts/model_registry.py +465 -0
  51. isa_model/scripts/start_mlflow.py +95 -0
  52. isa_model/scripts/training_tracker.py +257 -0
  53. isa_model/training/engine/llama_factory/__init__.py +39 -0
  54. isa_model/training/engine/llama_factory/config.py +115 -0
  55. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  56. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  57. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  58. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  59. isa_model/training/engine/llama_factory/factory.py +331 -0
  60. isa_model/training/engine/llama_factory/rl.py +254 -0
  61. isa_model/training/engine/llama_factory/trainer.py +171 -0
  62. isa_model/training/image_model/configs/create_config.py +37 -0
  63. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  64. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  65. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  66. isa_model/training/image_model/prepare_upload.py +17 -0
  67. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  68. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  69. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  70. isa_model/training/image_model/train/train.py +42 -0
  71. isa_model/training/image_model/train/train_flux.py +41 -0
  72. isa_model/training/image_model/train/train_lora.py +57 -0
  73. isa_model/training/image_model/train_main.py +25 -0
  74. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  75. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  76. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  77. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  78. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  79. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  80. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  81. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  82. isa_model-0.0.1.dist-info/METADATA +327 -0
  83. isa_model-0.0.1.dist-info/RECORD +86 -0
  84. isa_model-0.0.1.dist-info/WHEEL +5 -0
  85. isa_model-0.0.1.dist-info/licenses/LICENSE +21 -0
  86. isa_model-0.0.1.dist-info/top_level.txt +1 -0
isa_model/inference/services/others/table_transformer_service.py
@@ -0,0 +1,61 @@
+ from typing import Dict, Any, List, Union, Optional
+ from ..base_service import BaseService
+ from ...providers.base_provider import BaseProvider
+ from transformers import TableTransformerForObjectDetection, DetrImageProcessor
+ import torch
+ from PIL import Image
+ import numpy as np
+
+ class TableTransformerService(BaseService):
+     """Table detection service using Microsoft's Table Transformer"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str = "microsoft/table-transformer-detection"):
+         super().__init__(provider, model_name)
+         self.processor = DetrImageProcessor.from_pretrained(model_name)
+         self.model = TableTransformerForObjectDetection.from_pretrained(model_name)
+         if torch.cuda.is_available():
+             self.model = self.model.cuda()
+
+     async def detect_tables(self, image_path: str) -> Dict[str, Any]:
+         """Detect tables in image"""
+         try:
+             # Load and process image (force RGB so RGBA/paletted inputs don't fail)
+             image = Image.open(image_path).convert("RGB")
+             inputs = self.processor(images=image, return_tensors="pt")
+
+             if torch.cuda.is_available():
+                 inputs = {k: v.cuda() for k, v in inputs.items()}
+
+             # Run inference
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+
+             # Convert outputs to image size
+             target_sizes = torch.tensor([image.size[::-1]])
+             results = self.processor.post_process_object_detection(
+                 outputs, target_sizes=target_sizes, threshold=0.7
+             )[0]
+
+             tables = []
+             for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+                 # Resolve the class name from the model config
+                 # (for this checkpoint, id 0 is "table" and id 1 is "table rotated")
+                 if self.model.config.id2label[label.item()] == "table":
+                     tables.append({
+                         "confidence": score.item(),
+                         "bbox": box.tolist(),
+                         "type": "table"
+                     })
+
+             return {
+                 "tables": tables,
+                 "image_size": image.size
+             }
+
+         except Exception as e:
+             raise RuntimeError(f"Table detection failed: {e}")
+
+     async def close(self):
+         """Cleanup resources"""
+         if hasattr(self, 'model'):
+             del self.model
+         if hasattr(self, 'processor'):
+             del self.processor
+         torch.cuda.empty_cache()
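
A minimal usage sketch for the service above. It assumes BaseProvider can be constructed directly (it may be abstract in practice, in which case any concrete provider from the package works) and uses a hypothetical image path:

import asyncio
from isa_model.inference.providers.base_provider import BaseProvider
from isa_model.inference.services.others.table_transformer_service import TableTransformerService

async def main():
    provider = BaseProvider()  # assumption: default-constructible; real wiring may differ
    service = TableTransformerService(provider)
    result = await service.detect_tables("page_scan.png")  # hypothetical image path
    for table in result["tables"]:
        print(f"table at {table['bbox']} (confidence {table['confidence']:.2f})")
    await service.close()

asyncio.run(main())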
isa_model/inference/services/vision/__init__.py
@@ -0,0 +1,12 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+ """
+ Vision services package
+ Contains all vision-related service modules
+ """
+
+ # Export ReplicateVisionService
+ from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateVisionService
+
+ __all__ = ["ReplicateVisionService"]
isa_model/inference/services/vision/helpers/image_utils.py
@@ -0,0 +1,58 @@
+ from io import BytesIO
+ from PIL import Image
+ from typing import Union
+ import base64
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ def compress_image(image_data: Union[bytes, BytesIO], max_size: int = 1024) -> bytes:
+     """Compress an image to reduce its size
+
+     Args:
+         image_data: Image data, either bytes or BytesIO
+         max_size: Maximum dimension in pixels
+
+     Returns:
+         bytes: Compressed image data
+     """
+     try:
+         # If the input is bytes, wrap it in BytesIO
+         if isinstance(image_data, bytes):
+             image_data = BytesIO(image_data)
+
+         img = Image.open(image_data)
+
+         # Convert to RGB mode if needed
+         if img.mode in ('RGBA', 'P'):
+             img = img.convert('RGB')
+
+         # Compute the new size, preserving the aspect ratio
+         ratio = max_size / max(img.size)
+         if ratio < 1:
+             new_size = tuple(int(dim * ratio) for dim in img.size)
+             img = img.resize(new_size, Image.Resampling.LANCZOS)
+
+         # Save the compressed image
+         output = BytesIO()
+         img.save(output, format='JPEG', quality=85, optimize=True)
+         return output.getvalue()
+
+     except Exception as e:
+         logger.error(f"Error compressing image: {e}")
+         raise
+
+ def encode_image_to_base64(image_data: bytes) -> str:
+     """Encode image data as a base64 string
+
+     Args:
+         image_data: Binary image data
+
+     Returns:
+         str: base64-encoded string
+     """
+     try:
+         return base64.b64encode(image_data).decode('utf-8')
+     except Exception as e:
+         logger.error(f"Error encoding image to base64: {e}")
+         raise
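
A short sketch of how these helpers combine into the data URL that the vision services expect (hypothetical local file path):

from isa_model.inference.services.vision.helpers.image_utils import (
    compress_image, encode_image_to_base64,
)

with open("photo.png", "rb") as f:  # hypothetical input image
    raw = f.read()

small = compress_image(raw, max_size=1024)  # resized RGB JPEG at quality 85
b64 = encode_image_to_base64(small)
data_url = f"data:image/jpeg;base64,{b64}"
print(len(raw), "->", len(small), "bytes")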
isa_model/inference/services/vision/helpers/text_splitter.py
@@ -0,0 +1,46 @@
+ from typing import Dict, List, Optional
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from uuid import uuid4
+
+ class TextChunkHelper:
+     """Text splitting and chunking helper"""
+
+     def __init__(self,
+                  chunk_size: int = 512,
+                  chunk_overlap: int = 50,
+                  min_chunk_size: int = 50):
+         self.text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=chunk_size,
+             chunk_overlap=chunk_overlap,
+             length_function=len,
+             separators=["\n\n", "\n", ". ", ", ", " "]
+         )
+         self.min_chunk_size = min_chunk_size
+
+     def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
+         """Create text chunks with metadata"""
+         if not text or not isinstance(text, str):
+             raise ValueError("Text must be a non-empty string")
+
+         chunks = self.text_splitter.split_text(text)
+         valid_chunks = [
+             chunk for chunk in chunks
+             if len(chunk) >= self.min_chunk_size
+         ]
+
+         results = []
+         search_from = 0
+         for i, chunk in enumerate(valid_chunks):
+             # Search forward from the previous hit so repeated chunks
+             # don't all map to the first occurrence
+             start_idx = text.find(chunk, search_from)
+             if start_idx == -1:
+                 start_idx = text.find(chunk)
+             search_from = start_idx + 1
+             chunk_data = {
+                 "chunk_id": f"chunk_{uuid4().hex[:8]}",
+                 "content": chunk,
+                 "token_count": len(chunk),  # character count, used as a rough token proxy
+                 "metadata": {
+                     **(metadata or {}),
+                     "position": i,
+                     "start_idx": start_idx,
+                     "end_idx": start_idx + len(chunk)
+                 }
+             }
+             results.append(chunk_data)
+
+         return results
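
Usage is straightforward; a minimal sketch with hypothetical input text and metadata:

from isa_model.inference.services.vision.helpers.text_splitter import TextChunkHelper

helper = TextChunkHelper(chunk_size=128, chunk_overlap=16, min_chunk_size=20)
chunks = helper.create_chunks(
    "Tables are detected first. Each region is cropped and re-analyzed. " * 10,
    metadata={"source": "demo.txt"},  # hypothetical metadata
)
for c in chunks[:3]:
    print(c["chunk_id"], c["metadata"]["position"], c["content"][:40])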
isa_model/inference/services/vision/ollama_vision_service.py
@@ -0,0 +1,60 @@
+ import os
+ import json
+ import base64
+ import ollama
+ from typing import Dict, Any, Union
+ from tenacity import retry, stop_after_attempt, wait_exponential
+ from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.providers.base_provider import BaseProvider
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ class OllamaVisionService(BaseService):
+     """Vision model service wrapper for Ollama using base64 encoded images"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str = 'gemma3:4b'):
+         super().__init__(provider, model_name)
+         self.max_tokens = self.config.get('max_tokens', 1000)
+         self.temperature = self.config.get('temperature', 0.7)
+
+     @retry(
+         stop=stop_after_attempt(3),
+         wait=wait_exponential(multiplier=1, min=4, max=10),
+         reraise=True
+     )
+     async def analyze_image(self, image_data: Union[bytes, str], query: str) -> str:
+         """Analyze an image and return the result
+
+         Args:
+             image_data: Image data, either bytes or a path to an image file
+             query: Query text
+
+         Returns:
+             str: Analysis result
+         """
+         try:
+             # If given a file path, read the file contents
+             if isinstance(image_data, str):
+                 with open(image_data, 'rb') as f:
+                     image_data = f.read()
+
+             # Convert to base64
+             image_base64 = base64.b64encode(image_data).decode('utf-8')
+
+             # Call Ollama through its async client so the event loop isn't blocked
+             response = await ollama.AsyncClient().chat(
+                 model=self.model_name,
+                 messages=[{
+                     'role': 'user',
+                     'content': query,
+                     'images': [image_base64]
+                 }]
+             )
+
+             return response['message']['content']
+
+         except Exception as e:
+             logger.error(f"Error in image analysis: {e}")
+             raise
+
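
A minimal usage sketch, assuming a running local Ollama daemon with the model pulled and that OllamaProvider is default-constructible (the real constructor may take config):

import asyncio
from isa_model.inference.providers.ollama_provider import OllamaProvider
from isa_model.inference.services.vision.ollama_vision_service import OllamaVisionService

async def main():
    provider = OllamaProvider()  # assumption: default constructor
    service = OllamaVisionService(provider, model_name="gemma3:4b")
    answer = await service.analyze_image("photo.png", "What is in this image?")
    print(answer)

asyncio.run(main())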
isa_model/inference/services/vision/openai_vision_service.py
@@ -0,0 +1,80 @@
+ from typing import Dict, Any, Union
+ from openai import AsyncOpenAI
+ from tenacity import retry, stop_after_attempt, wait_exponential
+ from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.providers.base_provider import BaseProvider
+ from .helpers.image_utils import compress_image, encode_image_to_base64
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ class OpenAIVisionService(BaseService):
+     """Vision model service wrapper for YYDS"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str):
+         super().__init__(provider, model_name)
+         # Initialize the AsyncOpenAI client
+         self._client = AsyncOpenAI(
+             api_key=self.config.get('api_key'),
+             base_url=self.config.get('base_url')
+         )
+         self.max_tokens = self.config.get('max_tokens', 1000)
+         self.temperature = self.config.get('temperature', 0.7)
+
+     @property
+     def client(self) -> AsyncOpenAI:
+         """Return the underlying OpenAI client"""
+         return self._client
+
+     @retry(
+         stop=stop_after_attempt(3),
+         wait=wait_exponential(multiplier=1, min=4, max=10),
+         reraise=True
+     )
+     async def analyze_image(self, image_data: Union[bytes, str], query: str) -> str:
+         """Analyze an image and return the result
+
+         Args:
+             image_data: Image data, either bytes or an already-encoded base64 string
+             query: Query text
+
+         Returns:
+             str: Analysis result
+         """
+         try:
+             # Prepare the image data
+             if isinstance(image_data, bytes):
+                 # Compress and encode the image
+                 compressed_image = compress_image(image_data)
+                 image_b64 = encode_image_to_base64(compressed_image)
+             else:
+                 image_b64 = image_data
+
+             # Strip a data-URL prefix if one is present
+             if 'base64,' in image_b64:
+                 image_b64 = image_b64.split('base64,')[1]
+
+             # Create the request with the AsyncOpenAI client
+             response = await self._client.chat.completions.create(
+                 model=self.model_name,
+                 messages=[{
+                     "role": "user",
+                     "content": [
+                         {"type": "text", "text": query},
+                         {
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": f"data:image/jpeg;base64,{image_b64}"
+                             }
+                         }
+                     ]
+                 }],
+                 max_tokens=self.max_tokens,
+                 temperature=self.temperature
+             )
+
+             return response.choices[0].message.content
+
+         except Exception as e:
+             logger.error(f"Error in image analysis: {e}")
+             raise
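
A minimal usage sketch. The OpenAIProvider constructor signature is an assumption here (a config dict is plausible but unverified), and the model name and file path are examples:

import asyncio
from isa_model.inference.providers.openai_provider import OpenAIProvider
from isa_model.inference.services.vision.openai_vision_service import OpenAIVisionService

async def main():
    # assumption: provider accepts a config dict; real signature may differ
    provider = OpenAIProvider(config={"api_key": "sk-...", "base_url": None})
    service = OpenAIVisionService(provider, model_name="gpt-4o-mini")  # example model
    with open("photo.jpg", "rb") as f:  # hypothetical image
        answer = await service.analyze_image(f.read(), "Describe this image.")
    print(answer)

asyncio.run(main())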
isa_model/inference/services/vision/replicate_image_gen_service.py
@@ -0,0 +1,185 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+ """
+ Replicate Vision service
+ Interacts with the Replicate API; supports image generation and image analysis
+ """
+
+ import os
+ import time
+ import uuid
+ import logging
+ from typing import Dict, Any, List, Optional, Union, Tuple
+ import asyncio
+ import aiohttp
+ import replicate  # the replicate client library
+ from PIL import Image
+ from io import BytesIO
+
+ # Adjust the BaseService import path to match your project structure
+ from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.providers.base_provider import BaseProvider
+ from isa_model.inference.base import ModelType
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class ReplicateVisionService(BaseService):
+     """
+     Replicate Vision service for image generation and analysis.
+     Adjusted to use native async calls and streamlined file handling.
+     """
+
+     def __init__(self, provider: BaseProvider, model_name: str):
+         """
+         Initialize the Replicate Vision service
+         """
+         super().__init__(provider, model_name)
+         # Get the API token from the provider config or the environment
+         self.api_token = self.provider.config.get("api_token", os.environ.get("REPLICATE_API_TOKEN"))
+         self.model_type = ModelType.VISION
+
+         # Optional defaults
+         self.guidance_scale = self.provider.config.get("guidance_scale", 7.5)
+         self.num_inference_steps = self.provider.config.get("num_inference_steps", 30)
+
+         # Directory for generated images
+         self.output_dir = "generated_images"
+         os.makedirs(self.output_dir, exist_ok=True)
+
+         # ★ Adjusted: set the API token for the replicate library
+         if self.api_token:
+             # replicate reads the token from the environment; make sure it is set
+             os.environ["REPLICATE_API_TOKEN"] = self.api_token
+         else:
+             logger.warning("Replicate API token not found. The service may not work correctly.")
+
+     async def _prepare_input_files(self, input_data: Dict[str, Any]) -> Tuple[Dict[str, Any], List[Any]]:
+         """
+         ★ New helper: prepare input data by converting local file paths into file objects.
+         This lets the service handle local files and URLs uniformly.
+         """
+         prepared_input = input_data.copy()
+         files_to_close = []
+         for key, value in prepared_input.items():
+             # If the value is a string that looks like an existing local file path
+             if isinstance(value, str) and not value.startswith(('http://', 'https://')) and os.path.exists(value):
+                 logger.info(f"Detected local file path '{value}'; opening the file.")
+                 try:
+                     file_handle = open(value, "rb")
+                     prepared_input[key] = file_handle
+                     files_to_close.append(file_handle)
+                 except Exception as e:
+                     logger.error(f"Failed to open file '{value}': {e}")
+                     raise
+         return prepared_input, files_to_close
+
+     async def generate_image(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Generate images with a Replicate model (optimized to use native async)
+         """
+         prepared_input, files_to_close = await self._prepare_input_files(input_data)
+         try:
+             # Apply default parameters
+             if "guidance_scale" not in prepared_input:
+                 prepared_input["guidance_scale"] = self.guidance_scale
+             if "num_inference_steps" not in prepared_input:
+                 prepared_input["num_inference_steps"] = self.num_inference_steps
+
+             logger.info(f"Generating image with model {self.model_name} (native async)")
+
+             # ★ Adjusted: use the natively async replicate.async_run
+             output = await replicate.async_run(self.model_name, input=prepared_input)
+
+             # Normalize the result to a standard format (logic unchanged)
+             if isinstance(output, list):
+                 urls = output
+             else:
+                 urls = [output]
+
+             result = {
+                 "urls": urls,
+                 "metadata": {
+                     "model": self.model_name,
+                     "input": input_data  # return the original input for reference
+                 }
+             }
+             logger.info(f"Image generation finished: {result['urls']}")
+             return result
+         except Exception as e:
+             logger.error(f"Image generation failed: {e}")
+             raise
+         finally:
+             # ★ New: make sure all opened files are closed
+             for f in files_to_close:
+                 f.close()
+
+     async def analyze_image(self, image_path: str, prompt: str) -> Dict[str, Any]:
+         """
+         Analyze an image (optimized to use native async)
+         """
+         input_data = {"image": image_path, "prompt": prompt}
+         prepared_input, files_to_close = await self._prepare_input_files(input_data)
+         try:
+             logger.info(f"Analyzing image with model {self.model_name} (native async)")
+             # ★ Adjusted: use the natively async replicate.async_run
+             output = await replicate.async_run(self.model_name, input=prepared_input)
+
+             result = {
+                 "text": "".join(output) if isinstance(output, list) else output,
+                 "metadata": {
+                     "model": self.model_name,
+                     "input": input_data
+                 }
+             }
+             logger.info("Image analysis finished")
+             return result
+         except Exception as e:
+             logger.error(f"Image analysis failed: {e}")
+             raise
+         finally:
+             # ★ New: make sure all opened files are closed
+             for f in files_to_close:
+                 f.close()
+
+     async def generate_and_save(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+         """Generate images and save them locally (unchanged)"""
+         result = await self.generate_image(input_data)
+         saved_paths = []
+         for i, url in enumerate(result["urls"]):
+             timestamp = int(time.time())
+             file_name = f"{self.output_dir}/{timestamp}_{uuid.uuid4().hex[:8]}_{i+1}.png"
+             try:
+                 # Convert FileOutput objects to plain URL strings if necessary
+                 url_str = str(url)
+                 await self._download_image(url_str, file_name)
+                 saved_paths.append(file_name)
+                 logger.info(f"Image saved to: {file_name}")
+             except Exception as e:
+                 logger.error(f"Failed to save image: {e}")
+         result["saved_paths"] = saved_paths
+         return result
+
+     async def _download_image(self, url: str, save_path: str) -> None:
+         """Download an image asynchronously and save it (unchanged)"""
+         try:
+             async with aiohttp.ClientSession() as session:
+                 async with session.get(url) as response:
+                     response.raise_for_status()
+                     content = await response.read()
+                     with Image.open(BytesIO(content)) as img:
+                         img.save(save_path)
+         except Exception as e:
+             logger.error(f"Error downloading image: {url}, {e}")
+             raise
+
+     # `load` and `unload` are typically lightweight when targeting the Replicate API
+     async def load(self) -> None:
+         if not self.api_token:
+             raise ValueError("Missing Replicate API token; set the REPLICATE_API_TOKEN environment variable or provide it in the provider config")
+         logger.info(f"Replicate Vision service ready, using model: {self.model_name}")
+
+     async def unload(self) -> None:
+         logger.info(f"Unloading Replicate Vision service: {self.model_name}")
isa_model/inference/utils/conversion/bge_rerank_convert.py
@@ -0,0 +1,73 @@
+ import os
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ from pathlib import Path
+
+ def convert_bge_to_onnx(save_dir: str):
+     """Convert BGE reranker to ONNX format"""
+     try:
+         # Create save directory if it doesn't exist
+         save_dir = Path(save_dir).resolve()  # Get absolute path
+         save_dir.mkdir(parents=True, exist_ok=True)
+
+         model_name = "BAAI/bge-reranker-v2-m3"
+         save_path = str(save_dir / "model.onnx")  # Convert to string for absolute path
+
+         print(f"Loading model {model_name}...")
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         model = AutoModelForSequenceClassification.from_pretrained(model_name)
+         model.eval()
+
+         # Save tokenizer for later use
+         print("Saving tokenizer...")
+         tokenizer.save_pretrained(save_dir)
+
+         # Create dummy input
+         print("Creating dummy input...")
+         dummy_input = tokenizer(
+             [["what is panda?", "The giant panda is a bear species."]],
+             padding=True,
+             truncation=True,
+             return_tensors='pt',
+             max_length=512
+         )
+
+         # Export to ONNX with external data storage
+         print(f"Exporting to ONNX: {save_path}")
+         torch.onnx.export(
+             model,
+             (dummy_input['input_ids'], dummy_input['attention_mask']),
+             save_path,  # Using string absolute path
+             input_names=['input_ids', 'attention_mask'],
+             output_names=['logits'],
+             dynamic_axes={
+                 'input_ids': {0: 'batch', 1: 'sequence'},
+                 'attention_mask': {0: 'batch', 1: 'sequence'},
+                 'logits': {0: 'batch'}
+             },
+             opset_version=16,
+             export_params=True,  # Export the trained parameter weights
+             do_constant_folding=True,  # Optimize constant-folding
+             verbose=True,
+             use_external_data_format=True  # Enable external data storage
+         )
+         print("Conversion completed successfully!")
+         return True
+
+     except Exception as e:
+         print(f"Error during conversion: {e}")
+         return False
+
+ if __name__ == "__main__":
+     # Get the absolute path to the model directory
+     current_dir = Path(__file__).parent.parent
+     model_dir = current_dir / "model_converted" / "bge-reranker-v2-m3"
+
+     success = convert_bge_to_onnx(str(model_dir))
+     if success:
+         print(f"Model saved to: {model_dir}")
+         print("Files created:")
+         for file in model_dir.glob('*'):
+             print(f"- {file.name}")
+     else:
+         print("Conversion failed!")