isa-model 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_manager.py +69 -4
  3. isa_model/core/storage/hf_storage.py +419 -0
  4. isa_model/deployment/__init__.py +52 -0
  5. isa_model/deployment/core/__init__.py +34 -0
  6. isa_model/deployment/core/deployment_config.py +356 -0
  7. isa_model/deployment/core/deployment_manager.py +549 -0
  8. isa_model/deployment/core/isa_deployment_service.py +401 -0
  9. isa_model/eval/factory.py +381 -140
  10. isa_model/inference/ai_factory.py +427 -236
  11. isa_model/inference/billing_tracker.py +406 -0
  12. isa_model/inference/providers/base_provider.py +51 -4
  13. isa_model/inference/providers/ml_provider.py +50 -0
  14. isa_model/inference/providers/ollama_provider.py +37 -18
  15. isa_model/inference/providers/openai_provider.py +65 -36
  16. isa_model/inference/providers/replicate_provider.py +42 -30
  17. isa_model/inference/services/audio/base_stt_service.py +21 -2
  18. isa_model/inference/services/audio/openai_realtime_service.py +353 -0
  19. isa_model/inference/services/audio/openai_stt_service.py +252 -0
  20. isa_model/inference/services/audio/openai_tts_service.py +149 -9
  21. isa_model/inference/services/audio/replicate_tts_service.py +239 -0
  22. isa_model/inference/services/base_service.py +36 -1
  23. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  24. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  25. isa_model/inference/services/embedding/openai_embed_service.py +223 -0
  26. isa_model/inference/services/llm/__init__.py +2 -0
  27. isa_model/inference/services/llm/base_llm_service.py +158 -86
  28. isa_model/inference/services/llm/llm_adapter.py +414 -0
  29. isa_model/inference/services/llm/ollama_llm_service.py +252 -63
  30. isa_model/inference/services/llm/openai_llm_service.py +231 -93
  31. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  32. isa_model/inference/services/ml/base_ml_service.py +78 -0
  33. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  34. isa_model/inference/services/vision/__init__.py +3 -3
  35. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  36. isa_model/inference/services/vision/base_vision_service.py +177 -0
  37. isa_model/inference/services/vision/helpers/image_utils.py +4 -3
  38. isa_model/inference/services/vision/ollama_vision_service.py +151 -17
  39. isa_model/inference/services/vision/openai_vision_service.py +275 -41
  40. isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
  41. isa_model/training/__init__.py +62 -32
  42. isa_model/training/cloud/__init__.py +22 -0
  43. isa_model/training/cloud/job_orchestrator.py +402 -0
  44. isa_model/training/cloud/runpod_trainer.py +454 -0
  45. isa_model/training/cloud/storage_manager.py +482 -0
  46. isa_model/training/core/__init__.py +23 -0
  47. isa_model/training/core/config.py +181 -0
  48. isa_model/training/core/dataset.py +222 -0
  49. isa_model/training/core/trainer.py +720 -0
  50. isa_model/training/core/utils.py +213 -0
  51. isa_model/training/factory.py +229 -198
  52. isa_model-0.3.1.dist-info/METADATA +465 -0
  53. isa_model-0.3.1.dist-info/RECORD +91 -0
  54. isa_model/core/model_router.py +0 -226
  55. isa_model/core/model_version.py +0 -0
  56. isa_model/core/resource_manager.py +0 -202
  57. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  58. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  59. isa_model/training/engine/llama_factory/__init__.py +0 -39
  60. isa_model/training/engine/llama_factory/config.py +0 -115
  61. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  62. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  63. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  64. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  65. isa_model/training/engine/llama_factory/factory.py +0 -331
  66. isa_model/training/engine/llama_factory/rl.py +0 -254
  67. isa_model/training/engine/llama_factory/trainer.py +0 -171
  68. isa_model/training/image_model/configs/create_config.py +0 -37
  69. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  70. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  71. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  72. isa_model/training/image_model/prepare_upload.py +0 -17
  73. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  74. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  75. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  76. isa_model/training/image_model/train/train.py +0 -42
  77. isa_model/training/image_model/train/train_flux.py +0 -41
  78. isa_model/training/image_model/train/train_lora.py +0 -57
  79. isa_model/training/image_model/train_main.py +0 -25
  80. isa_model-0.2.0.dist-info/METADATA +0 -327
  81. isa_model-0.2.0.dist-info/RECORD +0 -92
  82. isa_model-0.2.0.dist-info/licenses/LICENSE +0 -21
  83. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  84. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  85. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  86. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  87. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  88. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  89. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  90. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  91. {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
  92. {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
@@ -2,168 +2,261 @@
2
2
  # -*- coding: utf-8 -*-
3
3
 
4
4
  """
5
- Replicate Vision服务
6
- 用于与Replicate API交互,支持图像生成和图像分析
5
+ Replicate 图像生成服务
6
+ 支持 flux-schnell (文生图) 和 flux-kontext-pro (图生图) 模型
7
7
  """
8
8
 
9
9
  import os
10
10
  import time
11
11
  import uuid
12
12
  import logging
13
- from typing import Dict, Any, List, Optional, Union, Tuple
13
+ from typing import Dict, Any, List, Optional, Union
14
14
  import asyncio
15
15
  import aiohttp
16
- import replicate # 导入 replicate 库
16
+ import replicate
17
17
  from PIL import Image
18
18
  from io import BytesIO
19
19
 
20
- # 调整 BaseService 的导入路径以匹配您的项目结构
21
- from isa_model.inference.services.base_service import BaseService
20
+ from isa_model.inference.services.vision.base_image_gen_service import BaseImageGenService
22
21
  from isa_model.inference.providers.base_provider import BaseProvider
23
- from isa_model.inference.base import ModelType
24
22
 
25
23
  # 设置日志记录
26
24
  logging.basicConfig(level=logging.INFO)
27
25
  logger = logging.getLogger(__name__)
28
26
 
29
- class ReplicateVisionService(BaseService):
27
+ class ReplicateImageGenService(BaseImageGenService):
30
28
  """
31
- Replicate Vision服务,用于处理图像生成和分析。
32
- 经过调整,使用原生异步调用并优化了文件处理。
29
+ Replicate 图像生成服务
30
+ - flux-schnell: 文生图 (t2i) - $3 per 1000 images
31
+ - flux-kontext-pro: 图生图 (i2i) - $0.04 per image
33
32
  """
34
33
 
35
34
  def __init__(self, provider: BaseProvider, model_name: str):
36
- """
37
- 初始化Replicate Vision服务
38
- """
39
35
  super().__init__(provider, model_name)
40
- # 从 provider 或环境变量获取 API token
41
- self.api_token = self.provider.config.get("api_token", os.environ.get("REPLICATE_API_TOKEN"))
42
- self.model_type = ModelType.VISION
43
36
 
44
- # 可选的默认配置
45
- self.guidance_scale = self.provider.config.get("guidance_scale", 7.5)
46
- self.num_inference_steps = self.provider.config.get("num_inference_steps", 30)
37
+ # 获取配置
38
+ provider_config = provider.get_full_config()
39
+ self.api_token = provider_config.get("api_token") or provider_config.get("replicate_api_token")
47
40
 
48
- # 生成的图像存储目录
41
+ if not self.api_token:
42
+ raise ValueError("Replicate API token not found in provider configuration")
43
+
44
+ # 设置 API token
45
+ os.environ["REPLICATE_API_TOKEN"] = self.api_token
46
+
47
+ # 生成图像存储目录
49
48
  self.output_dir = "generated_images"
50
49
  os.makedirs(self.output_dir, exist_ok=True)
51
50
 
52
- # ★ 调整点: 为 replicate 库设置 API token
53
- if self.api_token:
54
- # replicate 库会自动从环境变量读取,我们确保它被设置
55
- os.environ["REPLICATE_API_TOKEN"] = self.api_token
51
+ # 统计信息
52
+ self.last_generation_count = 0
53
+ self.total_generation_count = 0
54
+
55
+ logger.info(f"Initialized ReplicateImageGenService with model '{self.model_name}'")
56
+
57
+ async def generate_image(
58
+ self,
59
+ prompt: str,
60
+ negative_prompt: Optional[str] = None,
61
+ width: int = 512,
62
+ height: int = 512,
63
+ num_inference_steps: int = 4,
64
+ guidance_scale: float = 7.5,
65
+ seed: Optional[int] = None
66
+ ) -> Dict[str, Any]:
67
+ """生成单张图像 (文生图)"""
68
+
69
+ if "flux-schnell" in self.model_name:
70
+ # FLUX Schnell 参数
71
+ input_data = {
72
+ "prompt": prompt,
73
+ "go_fast": True,
74
+ "megapixels": "1",
75
+ "num_outputs": 1,
76
+ "aspect_ratio": "1:1",
77
+ "output_format": "jpg",
78
+ "output_quality": 90,
79
+ "num_inference_steps": 4
80
+ }
56
81
  else:
57
- logger.warning("Replicate API token 未找到。服务可能无法正常工作。")
58
-
59
- async def _prepare_input_files(self, input_data: Dict[str, Any]) -> Tuple[Dict[str, Any], List[Any]]:
60
- """
61
- 新增辅助函数: 准备输入数据,将本地文件路径转换为文件对象。
62
- 这使得服务能统一处理本地文件和URL。
63
- """
64
- prepared_input = input_data.copy()
65
- files_to_close = []
66
- for key, value in prepared_input.items():
67
- # 如果值是字符串,且看起来像一个存在的本地文件路径
68
- if isinstance(value, str) and not value.startswith(('http://', 'https://')) and os.path.exists(value):
69
- logger.info(f"检测到本地文件路径 '{value}',准备打开文件。")
70
- try:
71
- file_handle = open(value, "rb")
72
- prepared_input[key] = file_handle
73
- files_to_close.append(file_handle)
74
- except Exception as e:
75
- logger.error(f"打开文件失败 '{value}': {e}")
76
- raise
77
- return prepared_input, files_to_close
78
-
79
- async def generate_image(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
80
- """
81
- 使用Replicate模型生成图像 (已优化为原生异步)
82
- """
83
- prepared_input, files_to_close = await self._prepare_input_files(input_data)
84
- try:
85
- # 设置默认参数
86
- if "guidance_scale" not in prepared_input:
87
- prepared_input["guidance_scale"] = self.guidance_scale
88
- if "num_inference_steps" not in prepared_input:
89
- prepared_input["num_inference_steps"] = self.num_inference_steps
82
+ # 默认参数
83
+ input_data = {
84
+ "prompt": prompt,
85
+ "width": width,
86
+ "height": height,
87
+ "num_inference_steps": num_inference_steps,
88
+ "guidance_scale": guidance_scale
89
+ }
90
90
 
91
- logger.info(f"开始使用模型 {self.model_name} 生成图像 (原生异步)")
91
+ if negative_prompt:
92
+ input_data["negative_prompt"] = negative_prompt
93
+ if seed:
94
+ input_data["seed"] = seed
95
+
96
+ return await self._generate_internal(input_data)
97
+
98
+ async def image_to_image(
99
+ self,
100
+ prompt: str,
101
+ init_image: Union[str, Any],
102
+ strength: float = 0.8,
103
+ negative_prompt: Optional[str] = None,
104
+ num_inference_steps: int = 20,
105
+ guidance_scale: float = 7.5,
106
+ seed: Optional[int] = None
107
+ ) -> Dict[str, Any]:
108
+ """图生图"""
109
+
110
+ if "flux-kontext-pro" in self.model_name:
111
+ # FLUX Kontext Pro 参数
112
+ input_data = {
113
+ "prompt": prompt,
114
+ "input_image": init_image,
115
+ "aspect_ratio": "match_input_image",
116
+ "output_format": "jpg",
117
+ "safety_tolerance": 2
118
+ }
119
+ else:
120
+ # 默认参数
121
+ input_data = {
122
+ "prompt": prompt,
123
+ "image": init_image,
124
+ "strength": strength,
125
+ "num_inference_steps": num_inference_steps,
126
+ "guidance_scale": guidance_scale
127
+ }
92
128
 
93
- # ★ 调整点: 使用原生异步的 replicate.async_run
94
- output = await replicate.async_run(self.model_name, input=prepared_input)
129
+ if negative_prompt:
130
+ input_data["negative_prompt"] = negative_prompt
131
+ if seed:
132
+ input_data["seed"] = seed
133
+
134
+ return await self._generate_internal(input_data)
135
+
136
+ async def _generate_internal(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
137
+ """内部生成方法"""
138
+ try:
139
+ logger.info(f"开始使用模型 {self.model_name} 生成图像")
140
+
141
+ # 调用 Replicate API
142
+ output = await replicate.async_run(self.model_name, input=input_data)
95
143
 
96
- # 将结果转换为标准格式 (此部分逻辑无需改变)
144
+ # 处理输出
97
145
  if isinstance(output, list):
98
146
  urls = output
99
147
  else:
100
148
  urls = [output]
101
149
 
150
+ # 更新统计
151
+ self.last_generation_count = len(urls)
152
+ self.total_generation_count += len(urls)
153
+
154
+ # 计算成本
155
+ cost = self._calculate_cost(len(urls))
156
+
157
+ # 跟踪计费信息
158
+ from isa_model.inference.billing_tracker import ServiceType
159
+ self._track_usage(
160
+ service_type=ServiceType.IMAGE_GENERATION,
161
+ operation="image_generation",
162
+ input_units=len(urls), # 生成的图像数量
163
+ metadata={
164
+ "model": self.model_name,
165
+ "prompt": input_data.get("prompt", "")[:100], # 截取前100字符
166
+ "generation_type": "t2i" if "flux-schnell" in self.model_name else "i2i"
167
+ }
168
+ )
169
+
102
170
  result = {
103
171
  "urls": urls,
172
+ "count": len(urls),
173
+ "cost_usd": cost,
174
+ "model": self.model_name,
104
175
  "metadata": {
105
- "model": self.model_name,
106
- "input": input_data # 返回原始输入以供参考
176
+ "input": input_data,
177
+ "generation_count": len(urls)
107
178
  }
108
179
  }
109
- logger.info(f"图像生成完成: {result['urls']}")
180
+
181
+ logger.info(f"图像生成完成: {len(urls)} 张图像, 成本: ${cost:.6f}")
110
182
  return result
183
+
111
184
  except Exception as e:
112
185
  logger.error(f"图像生成失败: {e}")
113
186
  raise
114
- finally:
115
- # 新增: 确保所有打开的文件都被关闭
116
- for f in files_to_close:
117
- f.close()
118
-
119
- async def analyze_image(self, image_path: str, prompt: str) -> Dict[str, Any]:
120
- """
121
- 分析图像 (已优化为原生异步)
122
- """
123
- input_data = {"image": image_path, "prompt": prompt}
124
- prepared_input, files_to_close = await self._prepare_input_files(input_data)
125
- try:
126
- logger.info(f"开始使用模型 {self.model_name} 分析图像 (原生异步)")
127
- # ★ 调整点: 使用原生异步的 replicate.async_run
128
- output = await replicate.async_run(self.model_name, input=prepared_input)
187
+
188
+ def _calculate_cost(self, image_count: int) -> float:
189
+ """计算生成成本"""
190
+ from isa_model.core.model_manager import ModelManager
191
+
192
+ manager = ModelManager()
193
+
194
+ if "flux-schnell" in self.model_name:
195
+ # $3 per 1000 images
196
+ return (image_count / 1000) * 3.0
197
+ elif "flux-kontext-pro" in self.model_name:
198
+ # $0.04 per image
199
+ return image_count * 0.04
200
+ else:
201
+ # 使用 ModelManager 的定价
202
+ pricing = manager.get_model_pricing("replicate", self.model_name)
203
+ return (image_count / 1000) * pricing.get("input", 0.0)
204
+
205
+ async def generate_images(
206
+ self,
207
+ prompt: str,
208
+ num_images: int = 1,
209
+ negative_prompt: Optional[str] = None,
210
+ width: int = 512,
211
+ height: int = 512,
212
+ num_inference_steps: int = 4,
213
+ guidance_scale: float = 7.5,
214
+ seed: Optional[int] = None
215
+ ) -> List[Dict[str, Any]]:
216
+ """生成多张图像"""
217
+ results = []
218
+ for i in range(num_images):
219
+ current_seed = seed + i if seed else None
220
+ result = await self.generate_image(
221
+ prompt, negative_prompt, width, height,
222
+ num_inference_steps, guidance_scale, current_seed
223
+ )
224
+ results.append(result)
225
+ return results
226
+
227
+ async def generate_image_to_file(
228
+ self,
229
+ prompt: str,
230
+ output_path: str,
231
+ negative_prompt: Optional[str] = None,
232
+ width: int = 512,
233
+ height: int = 512,
234
+ num_inference_steps: int = 4,
235
+ guidance_scale: float = 7.5,
236
+ seed: Optional[int] = None
237
+ ) -> Dict[str, Any]:
238
+ """生成图像并保存到文件"""
239
+ result = await self.generate_image(
240
+ prompt, negative_prompt, width, height,
241
+ num_inference_steps, guidance_scale, seed
242
+ )
243
+
244
+ # 保存第一张图像
245
+ if result.get("urls"):
246
+ url = result["urls"][0]
247
+ url_str = str(url) if hasattr(url, "__str__") else url
248
+ await self._download_image(url_str, output_path)
129
249
 
130
- result = {
131
- "text": "".join(output) if isinstance(output, list) else output,
132
- "metadata": {
133
- "model": self.model_name,
134
- "input": input_data
135
- }
250
+ return {
251
+ "file_path": output_path,
252
+ "cost_usd": result.get("cost_usd", 0.0),
253
+ "model": self.model_name
136
254
  }
137
- logger.info(f"图像分析完成")
138
- return result
139
- except Exception as e:
140
- logger.error(f"图像分析失败: {e}")
141
- raise
142
- finally:
143
- # ★ 新增: 确保所有打开的文件都被关闭
144
- for f in files_to_close:
145
- f.close()
146
-
147
- async def generate_and_save(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
148
- """生成图像并保存到本地 (此方法无需修改)"""
149
- result = await self.generate_image(input_data)
150
- saved_paths = []
151
- for i, url in enumerate(result["urls"]):
152
- timestamp = int(time.time())
153
- file_name = f"{self.output_dir}/{timestamp}_{uuid.uuid4().hex[:8]}_{i+1}.png"
154
- try:
155
- # Convert FileOutput object to string if necessary
156
- url_str = str(url) if hasattr(url, "__str__") else url
157
- await self._download_image(url_str, file_name)
158
- saved_paths.append(file_name)
159
- logger.info(f"图像已保存至: {file_name}")
160
- except Exception as e:
161
- logger.error(f"保存图像失败: {e}")
162
- result["saved_paths"] = saved_paths
163
- return result
164
-
255
+ else:
256
+ raise ValueError("No image generated")
257
+
165
258
  async def _download_image(self, url: str, save_path: str) -> None:
166
- """异步下载图像并保存 (此方法无需修改)"""
259
+ """下载图像并保存"""
167
260
  try:
168
261
  async with aiohttp.ClientSession() as session:
169
262
  async with session.get(url) as response:
@@ -175,11 +268,78 @@ class ReplicateVisionService(BaseService):
175
268
  logger.error(f"下载图像时出错: {url}, {e}")
176
269
  raise
177
270
 
178
- # `load` `unload` 方法在Replicate API场景下通常是轻量级的
271
+ def get_generation_stats(self) -> Dict[str, Any]:
272
+ """获取生成统计信息"""
273
+ total_cost = 0.0
274
+ if "flux-schnell" in self.model_name:
275
+ total_cost = (self.total_generation_count / 1000) * 3.0
276
+ elif "flux-kontext-pro" in self.model_name:
277
+ total_cost = self.total_generation_count * 0.04
278
+
279
+ return {
280
+ "last_generation_count": self.last_generation_count,
281
+ "total_generation_count": self.total_generation_count,
282
+ "total_cost_usd": total_cost,
283
+ "model": self.model_name
284
+ }
285
+
286
+ def get_supported_sizes(self) -> List[Dict[str, int]]:
287
+ """获取支持的图像尺寸"""
288
+ if "flux" in self.model_name:
289
+ return [
290
+ {"width": 512, "height": 512},
291
+ {"width": 768, "height": 768},
292
+ {"width": 1024, "height": 1024},
293
+ ]
294
+ else:
295
+ return [
296
+ {"width": 512, "height": 512},
297
+ {"width": 768, "height": 768},
298
+ {"width": 1024, "height": 1024},
299
+ {"width": 768, "height": 1344},
300
+ {"width": 1344, "height": 768},
301
+ ]
302
+
303
+ def get_model_info(self) -> Dict[str, Any]:
304
+ """获取模型信息"""
305
+ if "flux-schnell" in self.model_name:
306
+ return {
307
+ "name": self.model_name,
308
+ "type": "t2i",
309
+ "cost_per_1000_images": 3.0,
310
+ "supports_negative_prompt": False,
311
+ "supports_img2img": False,
312
+ "max_steps": 4
313
+ }
314
+ elif "flux-kontext-pro" in self.model_name:
315
+ return {
316
+ "name": self.model_name,
317
+ "type": "i2i",
318
+ "cost_per_image": 0.04,
319
+ "supports_negative_prompt": False,
320
+ "supports_img2img": True,
321
+ "max_width": 1024,
322
+ "max_height": 1024
323
+ }
324
+ else:
325
+ return {
326
+ "name": self.model_name,
327
+ "type": "general",
328
+ "supports_negative_prompt": True,
329
+ "supports_img2img": True
330
+ }
331
+
179
332
  async def load(self) -> None:
333
+ """加载服务"""
180
334
  if not self.api_token:
181
- raise ValueError("缺少Replicate API令牌,请设置REPLICATE_API_TOKEN环境变量或在provider配置中提供")
182
- logger.info(f"Replicate Vision服务已准备就绪,使用模型: {self.model_name}")
335
+ raise ValueError("缺少 Replicate API 令牌")
336
+ logger.info(f"Replicate 图像生成服务已准备就绪,使用模型: {self.model_name}")
183
337
 
184
338
  async def unload(self) -> None:
185
- logger.info(f"卸载Replicate Vision服务: {self.model_name}")
339
+ """卸载服务"""
340
+ logger.info(f"卸载 Replicate 图像生成服务: {self.model_name}")
341
+
342
+ async def close(self):
343
+ """关闭服务"""
344
+ await self.unload()
345
+
@@ -1,44 +1,74 @@
1
1
  """
2
- ISA Model Training Framework
2
+ ISA Model Training Module
3
3
 
4
- This module provides unified interfaces for training various types of AI models:
5
- - LLM training with LlamaFactory
6
- - Image model training with Flux/LoRA
7
- - Model evaluation and benchmarking
4
+ Provides unified training capabilities for AI models including:
5
+ - Local training with SFT (Supervised Fine-Tuning)
6
+ - Cloud training on RunPod
7
+ - Model evaluation and management
8
+ - HuggingFace integration
8
9
 
9
- Usage:
10
- from isa_model.training import TrainingFactory
10
+ Example usage:
11
+ ```python
12
+ from isa_model.training import TrainingFactory, train_gemma
11
13
 
12
- # Create training factory
13
- factory = TrainingFactory()
14
+ # Quick Gemma training
15
+ model_path = train_gemma(
16
+ dataset_path="tatsu-lab/alpaca",
17
+ model_size="4b",
18
+ num_epochs=3
19
+ )
14
20
 
15
- # Fine-tune Gemma 3:4B
16
- model_path = factory.finetune_llm(
17
- model_name="gemma:4b",
18
- dataset_path="path/to/data.json",
19
- training_type="sft"
21
+ # Advanced training with custom configuration
22
+ factory = TrainingFactory()
23
+ model_path = factory.train_model(
24
+ model_name="google/gemma-2-4b-it",
25
+ dataset_path="your-dataset.json",
26
+ use_lora=True,
27
+ batch_size=4,
28
+ num_epochs=3
20
29
  )
30
+ ```
21
31
  """
22
32
 
23
- from .factory import TrainingFactory, finetune_gemma
24
- from .engine.llama_factory import (
25
- LlamaFactory,
26
- LlamaFactoryConfig,
27
- SFTConfig,
28
- RLConfig,
29
- DPOConfig,
30
- TrainingStrategy,
31
- DatasetFormat
33
+ # Import the new clean factory
34
+ from .factory import TrainingFactory, train_gemma
35
+
36
+ # Import core components
37
+ from .core import (
38
+ TrainingConfig,
39
+ LoRAConfig,
40
+ DatasetConfig,
41
+ BaseTrainer,
42
+ SFTTrainer,
43
+ TrainingUtils,
44
+ DatasetManager
45
+ )
46
+
47
+ # Import cloud training components
48
+ from .cloud import (
49
+ RunPodConfig,
50
+ StorageConfig,
51
+ JobConfig,
52
+ TrainingJobOrchestrator
32
53
  )
33
54
 
34
55
  __all__ = [
35
- "TrainingFactory",
36
- "finetune_gemma",
37
- "LlamaFactory",
38
- "LlamaFactoryConfig",
39
- "SFTConfig",
40
- "RLConfig",
41
- "DPOConfig",
42
- "TrainingStrategy",
43
- "DatasetFormat"
56
+ # Main factory
57
+ 'TrainingFactory',
58
+ 'train_gemma',
59
+
60
+ # Core components
61
+ 'TrainingConfig',
62
+ 'LoRAConfig',
63
+ 'DatasetConfig',
64
+ 'BaseTrainer',
65
+ 'SFTTrainer',
66
+ 'TrainingUtils',
67
+ 'DatasetManager',
68
+
69
+ # Cloud components
70
+ 'RunPodConfig',
71
+ 'StorageConfig',
72
+ 'JobConfig',
73
+ 'TrainingJobOrchestrator'
44
74
  ]
@@ -0,0 +1,22 @@
1
+ """
2
+ Cloud Training Module for ISA Model Framework
3
+
4
+ This module provides cloud training capabilities including:
5
+ - RunPod integration for on-demand GPU training
6
+ - Remote storage management (S3, GCS, etc.)
7
+ - Training job orchestration and monitoring
8
+ - Automatic resource scaling and management
9
+ """
10
+
11
+ from .runpod_trainer import RunPodTrainer
12
+ from .storage_manager import CloudStorageManager
13
+ from .job_orchestrator import TrainingJobOrchestrator
14
+
15
+ # Import config classes - these are defined in each module that needs them
16
+ # from ..core.config import RunPodConfig, StorageConfig, JobConfig
17
+
18
+ __all__ = [
19
+ "RunPodTrainer",
20
+ "CloudStorageManager",
21
+ "TrainingJobOrchestrator"
22
+ ]