isa-model 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_manager.py +69 -4
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +427 -236
- isa_model/inference/billing_tracker.py +406 -0
- isa_model/inference/providers/base_provider.py +51 -4
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model/inference/providers/openai_provider.py +65 -36
- isa_model/inference/providers/replicate_provider.py +42 -30
- isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model/inference/services/audio/openai_stt_service.py +252 -0
- isa_model/inference/services/audio/openai_tts_service.py +149 -9
- isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- isa_model/inference/services/base_service.py +36 -1
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +158 -86
- isa_model/inference/services/llm/llm_adapter.py +414 -0
- isa_model/inference/services/llm/ollama_llm_service.py +252 -63
- isa_model/inference/services/llm/openai_llm_service.py +231 -93
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- isa_model/inference/services/vision/ollama_vision_service.py +151 -17
- isa_model/inference/services/vision/openai_vision_service.py +275 -41
- isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.3.1.dist-info/METADATA +465 -0
- isa_model-0.3.1.dist-info/RECORD +91 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.2.0.dist-info/METADATA +0 -327
- isa_model-0.2.0.dist-info/RECORD +0 -92
- isa_model-0.2.0.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
@@ -2,168 +2,261 @@
|
|
2
2
|
# -*- coding: utf-8 -*-
|
3
3
|
|
4
4
|
"""
|
5
|
-
Replicate
|
6
|
-
|
5
|
+
Replicate 图像生成服务
|
6
|
+
支持 flux-schnell (文生图) 和 flux-kontext-pro (图生图) 模型
|
7
7
|
"""
|
8
8
|
|
9
9
|
import os
|
10
10
|
import time
|
11
11
|
import uuid
|
12
12
|
import logging
|
13
|
-
from typing import Dict, Any, List, Optional, Union
|
13
|
+
from typing import Dict, Any, List, Optional, Union
|
14
14
|
import asyncio
|
15
15
|
import aiohttp
|
16
|
-
import replicate
|
16
|
+
import replicate
|
17
17
|
from PIL import Image
|
18
18
|
from io import BytesIO
|
19
19
|
|
20
|
-
|
21
|
-
from isa_model.inference.services.base_service import BaseService
|
20
|
+
from isa_model.inference.services.vision.base_image_gen_service import BaseImageGenService
|
22
21
|
from isa_model.inference.providers.base_provider import BaseProvider
|
23
|
-
from isa_model.inference.base import ModelType
|
24
22
|
|
25
23
|
# 设置日志记录
|
26
24
|
logging.basicConfig(level=logging.INFO)
|
27
25
|
logger = logging.getLogger(__name__)
|
28
26
|
|
29
|
-
class
|
27
|
+
class ReplicateImageGenService(BaseImageGenService):
|
30
28
|
"""
|
31
|
-
Replicate
|
32
|
-
|
29
|
+
Replicate 图像生成服务
|
30
|
+
- flux-schnell: 文生图 (t2i) - $3 per 1000 images
|
31
|
+
- flux-kontext-pro: 图生图 (i2i) - $0.04 per image
|
33
32
|
"""
|
34
33
|
|
35
34
|
def __init__(self, provider: BaseProvider, model_name: str):
|
36
|
-
"""
|
37
|
-
初始化Replicate Vision服务
|
38
|
-
"""
|
39
35
|
super().__init__(provider, model_name)
|
40
|
-
# 从 provider 或环境变量获取 API token
|
41
|
-
self.api_token = self.provider.config.get("api_token", os.environ.get("REPLICATE_API_TOKEN"))
|
42
|
-
self.model_type = ModelType.VISION
|
43
36
|
|
44
|
-
#
|
45
|
-
|
46
|
-
self.
|
37
|
+
# 获取配置
|
38
|
+
provider_config = provider.get_full_config()
|
39
|
+
self.api_token = provider_config.get("api_token") or provider_config.get("replicate_api_token")
|
47
40
|
|
48
|
-
|
41
|
+
if not self.api_token:
|
42
|
+
raise ValueError("Replicate API token not found in provider configuration")
|
43
|
+
|
44
|
+
# 设置 API token
|
45
|
+
os.environ["REPLICATE_API_TOKEN"] = self.api_token
|
46
|
+
|
47
|
+
# 生成图像存储目录
|
49
48
|
self.output_dir = "generated_images"
|
50
49
|
os.makedirs(self.output_dir, exist_ok=True)
|
51
50
|
|
52
|
-
#
|
53
|
-
|
54
|
-
|
55
|
-
|
51
|
+
# 统计信息
|
52
|
+
self.last_generation_count = 0
|
53
|
+
self.total_generation_count = 0
|
54
|
+
|
55
|
+
logger.info(f"Initialized ReplicateImageGenService with model '{self.model_name}'")
|
56
|
+
|
57
|
+
async def generate_image(
|
58
|
+
self,
|
59
|
+
prompt: str,
|
60
|
+
negative_prompt: Optional[str] = None,
|
61
|
+
width: int = 512,
|
62
|
+
height: int = 512,
|
63
|
+
num_inference_steps: int = 4,
|
64
|
+
guidance_scale: float = 7.5,
|
65
|
+
seed: Optional[int] = None
|
66
|
+
) -> Dict[str, Any]:
|
67
|
+
"""生成单张图像 (文生图)"""
|
68
|
+
|
69
|
+
if "flux-schnell" in self.model_name:
|
70
|
+
# FLUX Schnell 参数
|
71
|
+
input_data = {
|
72
|
+
"prompt": prompt,
|
73
|
+
"go_fast": True,
|
74
|
+
"megapixels": "1",
|
75
|
+
"num_outputs": 1,
|
76
|
+
"aspect_ratio": "1:1",
|
77
|
+
"output_format": "jpg",
|
78
|
+
"output_quality": 90,
|
79
|
+
"num_inference_steps": 4
|
80
|
+
}
|
56
81
|
else:
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
files_to_close = []
|
66
|
-
for key, value in prepared_input.items():
|
67
|
-
# 如果值是字符串,且看起来像一个存在的本地文件路径
|
68
|
-
if isinstance(value, str) and not value.startswith(('http://', 'https://')) and os.path.exists(value):
|
69
|
-
logger.info(f"检测到本地文件路径 '{value}',准备打开文件。")
|
70
|
-
try:
|
71
|
-
file_handle = open(value, "rb")
|
72
|
-
prepared_input[key] = file_handle
|
73
|
-
files_to_close.append(file_handle)
|
74
|
-
except Exception as e:
|
75
|
-
logger.error(f"打开文件失败 '{value}': {e}")
|
76
|
-
raise
|
77
|
-
return prepared_input, files_to_close
|
78
|
-
|
79
|
-
async def generate_image(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
|
80
|
-
"""
|
81
|
-
使用Replicate模型生成图像 (已优化为原生异步)
|
82
|
-
"""
|
83
|
-
prepared_input, files_to_close = await self._prepare_input_files(input_data)
|
84
|
-
try:
|
85
|
-
# 设置默认参数
|
86
|
-
if "guidance_scale" not in prepared_input:
|
87
|
-
prepared_input["guidance_scale"] = self.guidance_scale
|
88
|
-
if "num_inference_steps" not in prepared_input:
|
89
|
-
prepared_input["num_inference_steps"] = self.num_inference_steps
|
82
|
+
# 默认参数
|
83
|
+
input_data = {
|
84
|
+
"prompt": prompt,
|
85
|
+
"width": width,
|
86
|
+
"height": height,
|
87
|
+
"num_inference_steps": num_inference_steps,
|
88
|
+
"guidance_scale": guidance_scale
|
89
|
+
}
|
90
90
|
|
91
|
-
|
91
|
+
if negative_prompt:
|
92
|
+
input_data["negative_prompt"] = negative_prompt
|
93
|
+
if seed:
|
94
|
+
input_data["seed"] = seed
|
95
|
+
|
96
|
+
return await self._generate_internal(input_data)
|
97
|
+
|
98
|
+
async def image_to_image(
|
99
|
+
self,
|
100
|
+
prompt: str,
|
101
|
+
init_image: Union[str, Any],
|
102
|
+
strength: float = 0.8,
|
103
|
+
negative_prompt: Optional[str] = None,
|
104
|
+
num_inference_steps: int = 20,
|
105
|
+
guidance_scale: float = 7.5,
|
106
|
+
seed: Optional[int] = None
|
107
|
+
) -> Dict[str, Any]:
|
108
|
+
"""图生图"""
|
109
|
+
|
110
|
+
if "flux-kontext-pro" in self.model_name:
|
111
|
+
# FLUX Kontext Pro 参数
|
112
|
+
input_data = {
|
113
|
+
"prompt": prompt,
|
114
|
+
"input_image": init_image,
|
115
|
+
"aspect_ratio": "match_input_image",
|
116
|
+
"output_format": "jpg",
|
117
|
+
"safety_tolerance": 2
|
118
|
+
}
|
119
|
+
else:
|
120
|
+
# 默认参数
|
121
|
+
input_data = {
|
122
|
+
"prompt": prompt,
|
123
|
+
"image": init_image,
|
124
|
+
"strength": strength,
|
125
|
+
"num_inference_steps": num_inference_steps,
|
126
|
+
"guidance_scale": guidance_scale
|
127
|
+
}
|
92
128
|
|
93
|
-
|
94
|
-
|
129
|
+
if negative_prompt:
|
130
|
+
input_data["negative_prompt"] = negative_prompt
|
131
|
+
if seed:
|
132
|
+
input_data["seed"] = seed
|
133
|
+
|
134
|
+
return await self._generate_internal(input_data)
|
135
|
+
|
136
|
+
async def _generate_internal(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
|
137
|
+
"""内部生成方法"""
|
138
|
+
try:
|
139
|
+
logger.info(f"开始使用模型 {self.model_name} 生成图像")
|
140
|
+
|
141
|
+
# 调用 Replicate API
|
142
|
+
output = await replicate.async_run(self.model_name, input=input_data)
|
95
143
|
|
96
|
-
#
|
144
|
+
# 处理输出
|
97
145
|
if isinstance(output, list):
|
98
146
|
urls = output
|
99
147
|
else:
|
100
148
|
urls = [output]
|
101
149
|
|
150
|
+
# 更新统计
|
151
|
+
self.last_generation_count = len(urls)
|
152
|
+
self.total_generation_count += len(urls)
|
153
|
+
|
154
|
+
# 计算成本
|
155
|
+
cost = self._calculate_cost(len(urls))
|
156
|
+
|
157
|
+
# 跟踪计费信息
|
158
|
+
from isa_model.inference.billing_tracker import ServiceType
|
159
|
+
self._track_usage(
|
160
|
+
service_type=ServiceType.IMAGE_GENERATION,
|
161
|
+
operation="image_generation",
|
162
|
+
input_units=len(urls), # 生成的图像数量
|
163
|
+
metadata={
|
164
|
+
"model": self.model_name,
|
165
|
+
"prompt": input_data.get("prompt", "")[:100], # 截取前100字符
|
166
|
+
"generation_type": "t2i" if "flux-schnell" in self.model_name else "i2i"
|
167
|
+
}
|
168
|
+
)
|
169
|
+
|
102
170
|
result = {
|
103
171
|
"urls": urls,
|
172
|
+
"count": len(urls),
|
173
|
+
"cost_usd": cost,
|
174
|
+
"model": self.model_name,
|
104
175
|
"metadata": {
|
105
|
-
"
|
106
|
-
"
|
176
|
+
"input": input_data,
|
177
|
+
"generation_count": len(urls)
|
107
178
|
}
|
108
179
|
}
|
109
|
-
|
180
|
+
|
181
|
+
logger.info(f"图像生成完成: {len(urls)} 张图像, 成本: ${cost:.6f}")
|
110
182
|
return result
|
183
|
+
|
111
184
|
except Exception as e:
|
112
185
|
logger.error(f"图像生成失败: {e}")
|
113
186
|
raise
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
187
|
+
|
188
|
+
def _calculate_cost(self, image_count: int) -> float:
|
189
|
+
"""计算生成成本"""
|
190
|
+
from isa_model.core.model_manager import ModelManager
|
191
|
+
|
192
|
+
manager = ModelManager()
|
193
|
+
|
194
|
+
if "flux-schnell" in self.model_name:
|
195
|
+
# $3 per 1000 images
|
196
|
+
return (image_count / 1000) * 3.0
|
197
|
+
elif "flux-kontext-pro" in self.model_name:
|
198
|
+
# $0.04 per image
|
199
|
+
return image_count * 0.04
|
200
|
+
else:
|
201
|
+
# 使用 ModelManager 的定价
|
202
|
+
pricing = manager.get_model_pricing("replicate", self.model_name)
|
203
|
+
return (image_count / 1000) * pricing.get("input", 0.0)
|
204
|
+
|
205
|
+
async def generate_images(
|
206
|
+
self,
|
207
|
+
prompt: str,
|
208
|
+
num_images: int = 1,
|
209
|
+
negative_prompt: Optional[str] = None,
|
210
|
+
width: int = 512,
|
211
|
+
height: int = 512,
|
212
|
+
num_inference_steps: int = 4,
|
213
|
+
guidance_scale: float = 7.5,
|
214
|
+
seed: Optional[int] = None
|
215
|
+
) -> List[Dict[str, Any]]:
|
216
|
+
"""生成多张图像"""
|
217
|
+
results = []
|
218
|
+
for i in range(num_images):
|
219
|
+
current_seed = seed + i if seed else None
|
220
|
+
result = await self.generate_image(
|
221
|
+
prompt, negative_prompt, width, height,
|
222
|
+
num_inference_steps, guidance_scale, current_seed
|
223
|
+
)
|
224
|
+
results.append(result)
|
225
|
+
return results
|
226
|
+
|
227
|
+
async def generate_image_to_file(
|
228
|
+
self,
|
229
|
+
prompt: str,
|
230
|
+
output_path: str,
|
231
|
+
negative_prompt: Optional[str] = None,
|
232
|
+
width: int = 512,
|
233
|
+
height: int = 512,
|
234
|
+
num_inference_steps: int = 4,
|
235
|
+
guidance_scale: float = 7.5,
|
236
|
+
seed: Optional[int] = None
|
237
|
+
) -> Dict[str, Any]:
|
238
|
+
"""生成图像并保存到文件"""
|
239
|
+
result = await self.generate_image(
|
240
|
+
prompt, negative_prompt, width, height,
|
241
|
+
num_inference_steps, guidance_scale, seed
|
242
|
+
)
|
243
|
+
|
244
|
+
# 保存第一张图像
|
245
|
+
if result.get("urls"):
|
246
|
+
url = result["urls"][0]
|
247
|
+
url_str = str(url) if hasattr(url, "__str__") else url
|
248
|
+
await self._download_image(url_str, output_path)
|
129
249
|
|
130
|
-
|
131
|
-
"
|
132
|
-
"
|
133
|
-
|
134
|
-
"input": input_data
|
135
|
-
}
|
250
|
+
return {
|
251
|
+
"file_path": output_path,
|
252
|
+
"cost_usd": result.get("cost_usd", 0.0),
|
253
|
+
"model": self.model_name
|
136
254
|
}
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
logger.error(f"图像分析失败: {e}")
|
141
|
-
raise
|
142
|
-
finally:
|
143
|
-
# ★ 新增: 确保所有打开的文件都被关闭
|
144
|
-
for f in files_to_close:
|
145
|
-
f.close()
|
146
|
-
|
147
|
-
async def generate_and_save(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
|
148
|
-
"""生成图像并保存到本地 (此方法无需修改)"""
|
149
|
-
result = await self.generate_image(input_data)
|
150
|
-
saved_paths = []
|
151
|
-
for i, url in enumerate(result["urls"]):
|
152
|
-
timestamp = int(time.time())
|
153
|
-
file_name = f"{self.output_dir}/{timestamp}_{uuid.uuid4().hex[:8]}_{i+1}.png"
|
154
|
-
try:
|
155
|
-
# Convert FileOutput object to string if necessary
|
156
|
-
url_str = str(url) if hasattr(url, "__str__") else url
|
157
|
-
await self._download_image(url_str, file_name)
|
158
|
-
saved_paths.append(file_name)
|
159
|
-
logger.info(f"图像已保存至: {file_name}")
|
160
|
-
except Exception as e:
|
161
|
-
logger.error(f"保存图像失败: {e}")
|
162
|
-
result["saved_paths"] = saved_paths
|
163
|
-
return result
|
164
|
-
|
255
|
+
else:
|
256
|
+
raise ValueError("No image generated")
|
257
|
+
|
165
258
|
async def _download_image(self, url: str, save_path: str) -> None:
|
166
|
-
"""
|
259
|
+
"""下载图像并保存"""
|
167
260
|
try:
|
168
261
|
async with aiohttp.ClientSession() as session:
|
169
262
|
async with session.get(url) as response:
|
@@ -175,11 +268,78 @@ class ReplicateVisionService(BaseService):
|
|
175
268
|
logger.error(f"下载图像时出错: {url}, {e}")
|
176
269
|
raise
|
177
270
|
|
178
|
-
|
271
|
+
def get_generation_stats(self) -> Dict[str, Any]:
|
272
|
+
"""获取生成统计信息"""
|
273
|
+
total_cost = 0.0
|
274
|
+
if "flux-schnell" in self.model_name:
|
275
|
+
total_cost = (self.total_generation_count / 1000) * 3.0
|
276
|
+
elif "flux-kontext-pro" in self.model_name:
|
277
|
+
total_cost = self.total_generation_count * 0.04
|
278
|
+
|
279
|
+
return {
|
280
|
+
"last_generation_count": self.last_generation_count,
|
281
|
+
"total_generation_count": self.total_generation_count,
|
282
|
+
"total_cost_usd": total_cost,
|
283
|
+
"model": self.model_name
|
284
|
+
}
|
285
|
+
|
286
|
+
def get_supported_sizes(self) -> List[Dict[str, int]]:
|
287
|
+
"""获取支持的图像尺寸"""
|
288
|
+
if "flux" in self.model_name:
|
289
|
+
return [
|
290
|
+
{"width": 512, "height": 512},
|
291
|
+
{"width": 768, "height": 768},
|
292
|
+
{"width": 1024, "height": 1024},
|
293
|
+
]
|
294
|
+
else:
|
295
|
+
return [
|
296
|
+
{"width": 512, "height": 512},
|
297
|
+
{"width": 768, "height": 768},
|
298
|
+
{"width": 1024, "height": 1024},
|
299
|
+
{"width": 768, "height": 1344},
|
300
|
+
{"width": 1344, "height": 768},
|
301
|
+
]
|
302
|
+
|
303
|
+
def get_model_info(self) -> Dict[str, Any]:
|
304
|
+
"""获取模型信息"""
|
305
|
+
if "flux-schnell" in self.model_name:
|
306
|
+
return {
|
307
|
+
"name": self.model_name,
|
308
|
+
"type": "t2i",
|
309
|
+
"cost_per_1000_images": 3.0,
|
310
|
+
"supports_negative_prompt": False,
|
311
|
+
"supports_img2img": False,
|
312
|
+
"max_steps": 4
|
313
|
+
}
|
314
|
+
elif "flux-kontext-pro" in self.model_name:
|
315
|
+
return {
|
316
|
+
"name": self.model_name,
|
317
|
+
"type": "i2i",
|
318
|
+
"cost_per_image": 0.04,
|
319
|
+
"supports_negative_prompt": False,
|
320
|
+
"supports_img2img": True,
|
321
|
+
"max_width": 1024,
|
322
|
+
"max_height": 1024
|
323
|
+
}
|
324
|
+
else:
|
325
|
+
return {
|
326
|
+
"name": self.model_name,
|
327
|
+
"type": "general",
|
328
|
+
"supports_negative_prompt": True,
|
329
|
+
"supports_img2img": True
|
330
|
+
}
|
331
|
+
|
179
332
|
async def load(self) -> None:
|
333
|
+
"""加载服务"""
|
180
334
|
if not self.api_token:
|
181
|
-
raise ValueError("缺少Replicate API
|
182
|
-
logger.info(f"Replicate
|
335
|
+
raise ValueError("缺少 Replicate API 令牌")
|
336
|
+
logger.info(f"Replicate 图像生成服务已准备就绪,使用模型: {self.model_name}")
|
183
337
|
|
184
338
|
async def unload(self) -> None:
|
185
|
-
|
339
|
+
"""卸载服务"""
|
340
|
+
logger.info(f"卸载 Replicate 图像生成服务: {self.model_name}")
|
341
|
+
|
342
|
+
async def close(self):
|
343
|
+
"""关闭服务"""
|
344
|
+
await self.unload()
|
345
|
+
|
isa_model/training/__init__.py
CHANGED
@@ -1,44 +1,74 @@
|
|
1
1
|
"""
|
2
|
-
ISA Model Training
|
2
|
+
ISA Model Training Module
|
3
3
|
|
4
|
-
|
5
|
-
-
|
6
|
-
-
|
7
|
-
- Model evaluation and
|
4
|
+
Provides unified training capabilities for AI models including:
|
5
|
+
- Local training with SFT (Supervised Fine-Tuning)
|
6
|
+
- Cloud training on RunPod
|
7
|
+
- Model evaluation and management
|
8
|
+
- HuggingFace integration
|
8
9
|
|
9
|
-
|
10
|
-
|
10
|
+
Example usage:
|
11
|
+
```python
|
12
|
+
from isa_model.training import TrainingFactory, train_gemma
|
11
13
|
|
12
|
-
#
|
13
|
-
|
14
|
+
# Quick Gemma training
|
15
|
+
model_path = train_gemma(
|
16
|
+
dataset_path="tatsu-lab/alpaca",
|
17
|
+
model_size="4b",
|
18
|
+
num_epochs=3
|
19
|
+
)
|
14
20
|
|
15
|
-
#
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
21
|
+
# Advanced training with custom configuration
|
22
|
+
factory = TrainingFactory()
|
23
|
+
model_path = factory.train_model(
|
24
|
+
model_name="google/gemma-2-4b-it",
|
25
|
+
dataset_path="your-dataset.json",
|
26
|
+
use_lora=True,
|
27
|
+
batch_size=4,
|
28
|
+
num_epochs=3
|
20
29
|
)
|
30
|
+
```
|
21
31
|
"""
|
22
32
|
|
23
|
-
|
24
|
-
from .
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
33
|
+
# Import the new clean factory
|
34
|
+
from .factory import TrainingFactory, train_gemma
|
35
|
+
|
36
|
+
# Import core components
|
37
|
+
from .core import (
|
38
|
+
TrainingConfig,
|
39
|
+
LoRAConfig,
|
40
|
+
DatasetConfig,
|
41
|
+
BaseTrainer,
|
42
|
+
SFTTrainer,
|
43
|
+
TrainingUtils,
|
44
|
+
DatasetManager
|
45
|
+
)
|
46
|
+
|
47
|
+
# Import cloud training components
|
48
|
+
from .cloud import (
|
49
|
+
RunPodConfig,
|
50
|
+
StorageConfig,
|
51
|
+
JobConfig,
|
52
|
+
TrainingJobOrchestrator
|
32
53
|
)
|
33
54
|
|
34
55
|
__all__ = [
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
56
|
+
# Main factory
|
57
|
+
'TrainingFactory',
|
58
|
+
'train_gemma',
|
59
|
+
|
60
|
+
# Core components
|
61
|
+
'TrainingConfig',
|
62
|
+
'LoRAConfig',
|
63
|
+
'DatasetConfig',
|
64
|
+
'BaseTrainer',
|
65
|
+
'SFTTrainer',
|
66
|
+
'TrainingUtils',
|
67
|
+
'DatasetManager',
|
68
|
+
|
69
|
+
# Cloud components
|
70
|
+
'RunPodConfig',
|
71
|
+
'StorageConfig',
|
72
|
+
'JobConfig',
|
73
|
+
'TrainingJobOrchestrator'
|
44
74
|
]
|
@@ -0,0 +1,22 @@
|
|
1
|
+
"""
|
2
|
+
Cloud Training Module for ISA Model Framework
|
3
|
+
|
4
|
+
This module provides cloud training capabilities including:
|
5
|
+
- RunPod integration for on-demand GPU training
|
6
|
+
- Remote storage management (S3, GCS, etc.)
|
7
|
+
- Training job orchestration and monitoring
|
8
|
+
- Automatic resource scaling and management
|
9
|
+
"""
|
10
|
+
|
11
|
+
from .runpod_trainer import RunPodTrainer
|
12
|
+
from .storage_manager import CloudStorageManager
|
13
|
+
from .job_orchestrator import TrainingJobOrchestrator
|
14
|
+
|
15
|
+
# Import config classes - these are defined in each module that needs them
|
16
|
+
# from ..core.config import RunPodConfig, StorageConfig, JobConfig
|
17
|
+
|
18
|
+
__all__ = [
|
19
|
+
"RunPodTrainer",
|
20
|
+
"CloudStorageManager",
|
21
|
+
"TrainingJobOrchestrator"
|
22
|
+
]
|