isa-model 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_registry.py +273 -46
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/eval/__init__.py +56 -0
- isa_model/eval/benchmarks.py +469 -0
- isa_model/eval/factory.py +582 -0
- isa_model/eval/metrics.py +628 -0
- isa_model/inference/ai_factory.py +98 -93
- isa_model/inference/providers/openai_provider.py +21 -7
- isa_model/inference/providers/replicate_provider.py +18 -5
- isa_model/inference/providers/triton_provider.py +1 -1
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
- isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
- isa_model/inference/services/llm/__init__.py +0 -4
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +1 -10
- isa_model/inference/services/llm/openai_llm_service.py +70 -61
- isa_model/inference/services/vision/__init__.py +1 -1
- isa_model/inference/services/vision/ollama_vision_service.py +4 -4
- isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/training/__init__.py +44 -0
- isa_model/training/factory.py +393 -0
- isa_model-0.1.1.dist-info/METADATA +327 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/RECORD +35 -60
- isa_model/deployment/mlflow_gateway/__init__.py +0 -8
- isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
- isa_model/deployment/unified_multimodal_client.py +0 -341
- isa_model/inference/adapter/triton_adapter.py +0 -453
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
- isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
- isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
- isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
- isa_model/inference/backends/__init__.py +0 -53
- isa_model/inference/backends/base_backend_client.py +0 -26
- isa_model/inference/backends/container_services.py +0 -104
- isa_model/inference/backends/local_services.py +0 -72
- isa_model/inference/backends/openai_client.py +0 -130
- isa_model/inference/backends/replicate_client.py +0 -197
- isa_model/inference/backends/third_party_services.py +0 -239
- isa_model/inference/backends/triton_client.py +0 -97
- isa_model/inference/client_sdk/client.py +0 -134
- isa_model/inference/client_sdk/client_data_std.py +0 -34
- isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +0 -83
- isa_model/inference/services/audio/fish_speech/handler.py +0 -215
- isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
- isa_model/inference/services/audio/triton_speech_service.py +0 -138
- isa_model/inference/services/audio/whisper_service.py +0 -186
- isa_model/inference/services/base_tts_service.py +0 -66
- isa_model/inference/services/embedding/bge_service.py +0 -183
- isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
- isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
- isa_model/inference/services/llm/gemma_service.py +0 -143
- isa_model/inference/services/llm/llama_service.py +0 -143
- isa_model/inference/services/llm/replicate_llm_service.py +0 -179
- isa_model/inference/services/llm/triton_llm_service.py +0 -230
- isa_model/inference/services/vision/replicate_vision_service.py +0 -241
- isa_model/inference/services/vision/triton_vision_service.py +0 -199
- isa_model-0.1.0.dist-info/METADATA +0 -116
- /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/WHEEL +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/top_level.txt +0 -0
isa_model/inference/services/vision/replicate_image_gen_service.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Replicate Vision Service
+Interacts with the Replicate API, supporting image generation and image analysis
+"""
+
+import os
+import time
+import uuid
+import logging
+from typing import Dict, Any, List, Optional, Union, Tuple
+import asyncio
+import aiohttp
+import replicate  # import the replicate library
+from PIL import Image
+from io import BytesIO
+
+# Adjust the BaseService import path to match your project structure
+from isa_model.inference.services.base_service import BaseService
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.base import ModelType
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class ReplicateVisionService(BaseService):
+    """
+    Replicate Vision service for handling image generation and analysis.
+    Adjusted to use native async calls and optimized file handling.
+    """
+
+    def __init__(self, provider: BaseProvider, model_name: str):
+        """
+        Initialize the Replicate Vision service
+        """
+        super().__init__(provider, model_name)
+        # Get the API token from the provider or the environment variable
+        self.api_token = self.provider.config.get("api_token", os.environ.get("REPLICATE_API_TOKEN"))
+        self.model_type = ModelType.VISION
+
+        # Optional default configuration
+        self.guidance_scale = self.provider.config.get("guidance_scale", 7.5)
+        self.num_inference_steps = self.provider.config.get("num_inference_steps", 30)
+
+        # Directory for storing generated images
+        self.output_dir = "generated_images"
+        os.makedirs(self.output_dir, exist_ok=True)
+
+        # ★ Adjustment: set the API token for the replicate library
+        if self.api_token:
+            # The replicate library reads the environment variable automatically; make sure it is set
+            os.environ["REPLICATE_API_TOKEN"] = self.api_token
+        else:
+            logger.warning("Replicate API token not found. The service may not work properly.")
+
+    async def _prepare_input_files(self, input_data: Dict[str, Any]) -> Tuple[Dict[str, Any], List[Any]]:
+        """
+        ★ New helper: prepare the input data, converting local file paths into file objects.
+        This lets the service handle local files and URLs uniformly.
+        """
+        prepared_input = input_data.copy()
+        files_to_close = []
+        for key, value in prepared_input.items():
+            # If the value is a string that looks like an existing local file path
+            if isinstance(value, str) and not value.startswith(('http://', 'https://')) and os.path.exists(value):
+                logger.info(f"Detected local file path '{value}', opening the file.")
+                try:
+                    file_handle = open(value, "rb")
+                    prepared_input[key] = file_handle
+                    files_to_close.append(file_handle)
+                except Exception as e:
+                    logger.error(f"Failed to open file '{value}': {e}")
+                    raise
+        return prepared_input, files_to_close
+
+    async def generate_image(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Generate images with a Replicate model (optimized to native async)
+        """
+        prepared_input, files_to_close = await self._prepare_input_files(input_data)
+        try:
+            # Set default parameters
+            if "guidance_scale" not in prepared_input:
+                prepared_input["guidance_scale"] = self.guidance_scale
+            if "num_inference_steps" not in prepared_input:
+                prepared_input["num_inference_steps"] = self.num_inference_steps
+
+            logger.info(f"Starting image generation with model {self.model_name} (native async)")
+
+            # ★ Adjustment: use the natively async replicate.async_run
+            output = await replicate.async_run(self.model_name, input=prepared_input)
+
+            # Convert the result to a standard format (this logic is unchanged)
+            if isinstance(output, list):
+                urls = output
+            else:
+                urls = [output]
+
+            result = {
+                "urls": urls,
+                "metadata": {
+                    "model": self.model_name,
+                    "input": input_data  # return the original input for reference
+                }
+            }
+            logger.info(f"Image generation finished: {result['urls']}")
+            return result
+        except Exception as e:
+            logger.error(f"Image generation failed: {e}")
+            raise
+        finally:
+            # ★ New: make sure all opened files are closed
+            for f in files_to_close:
+                f.close()
+
+    async def analyze_image(self, image_path: str, prompt: str) -> Dict[str, Any]:
+        """
+        Analyze an image (optimized to native async)
+        """
+        input_data = {"image": image_path, "prompt": prompt}
+        prepared_input, files_to_close = await self._prepare_input_files(input_data)
+        try:
+            logger.info(f"Starting image analysis with model {self.model_name} (native async)")
+            # ★ Adjustment: use the natively async replicate.async_run
+            output = await replicate.async_run(self.model_name, input=prepared_input)
+
+            result = {
+                "text": "".join(output) if isinstance(output, list) else output,
+                "metadata": {
+                    "model": self.model_name,
+                    "input": input_data
+                }
+            }
+            logger.info("Image analysis finished")
+            return result
+        except Exception as e:
+            logger.error(f"Image analysis failed: {e}")
+            raise
+        finally:
+            # ★ New: make sure all opened files are closed
+            for f in files_to_close:
+                f.close()
+
+    async def generate_and_save(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Generate images and save them locally (this method needs no changes)"""
+        result = await self.generate_image(input_data)
+        saved_paths = []
+        for i, url in enumerate(result["urls"]):
+            timestamp = int(time.time())
+            file_name = f"{self.output_dir}/{timestamp}_{uuid.uuid4().hex[:8]}_{i+1}.png"
+            try:
+                # Convert FileOutput object to string if necessary
+                url_str = str(url) if hasattr(url, "__str__") else url
+                await self._download_image(url_str, file_name)
+                saved_paths.append(file_name)
+                logger.info(f"Image saved to: {file_name}")
+            except Exception as e:
+                logger.error(f"Failed to save image: {e}")
+        result["saved_paths"] = saved_paths
+        return result
+
+    async def _download_image(self, url: str, save_path: str) -> None:
+        """Download an image asynchronously and save it (this method needs no changes)"""
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.get(url) as response:
+                    response.raise_for_status()
+                    content = await response.read()
+                    with Image.open(BytesIO(content)) as img:
+                        img.save(save_path)
+        except Exception as e:
+            logger.error(f"Error downloading image: {url}, {e}")
+            raise
+
+    # The `load` and `unload` methods are usually lightweight in the Replicate API scenario
+    async def load(self) -> None:
+        if not self.api_token:
+            raise ValueError("Missing Replicate API token. Set the REPLICATE_API_TOKEN environment variable or provide it in the provider config")
+        logger.info(f"Replicate Vision service is ready, using model: {self.model_name}")
+
+    async def unload(self) -> None:
+        logger.info(f"Unloading Replicate Vision service: {self.model_name}")
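
For orientation, a minimal sketch of how this new service could be driven from an async caller. Only the `ReplicateVisionService` class and its methods come from the diff above; the `ReplicateProvider` constructor, its config shape, and the model identifier are assumptions used for illustration.

```python
# Illustrative sketch only; the ReplicateProvider constructor and the model name
# below are assumptions, not confirmed by this diff.
import asyncio

from isa_model.inference.providers.replicate_provider import ReplicateProvider
from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateVisionService


async def main() -> None:
    provider = ReplicateProvider(config={"api_token": "r8_..."})  # hypothetical config shape
    service = ReplicateVisionService(provider, "black-forest-labs/flux-schnell")  # any Replicate image model
    await service.load()
    result = await service.generate_and_save({"prompt": "a watercolor painting of a fox"})
    print(result["saved_paths"])  # local PNG paths written under generated_images/
    await service.unload()


asyncio.run(main())
```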
isa_model/training/__init__.py
@@ -0,0 +1,44 @@
+"""
+ISA Model Training Framework
+
+This module provides unified interfaces for training various types of AI models:
+- LLM training with LlamaFactory
+- Image model training with Flux/LoRA
+- Model evaluation and benchmarking
+
+Usage:
+    from isa_model.training import TrainingFactory
+
+    # Create training factory
+    factory = TrainingFactory()
+
+    # Fine-tune Gemma 3:4B
+    model_path = factory.finetune_llm(
+        model_name="gemma:4b",
+        dataset_path="path/to/data.json",
+        training_type="sft"
+    )
+"""
+
+from .factory import TrainingFactory, finetune_gemma
+from .engine.llama_factory import (
+    LlamaFactory,
+    LlamaFactoryConfig,
+    SFTConfig,
+    RLConfig,
+    DPOConfig,
+    TrainingStrategy,
+    DatasetFormat
+)
+
+__all__ = [
+    "TrainingFactory",
+    "finetune_gemma",
+    "LlamaFactory",
+    "LlamaFactoryConfig",
+    "SFTConfig",
+    "RLConfig",
+    "DPOConfig",
+    "TrainingStrategy",
+    "DatasetFormat"
+]
isa_model/training/factory.py
@@ -0,0 +1,393 @@
+"""
+Unified Training Factory for ISA Model Framework
+
+This factory provides a single interface for all training operations:
+- LLM fine-tuning (SFT, DPO, RLHF)
+- Image model training (Flux, LoRA)
+- Model evaluation and benchmarking
+"""
+
+import os
+import logging
+from typing import Optional, Dict, Any, Union, List
+from pathlib import Path
+import datetime
+
+from .engine.llama_factory import LlamaFactory, TrainingStrategy, DatasetFormat
+from .engine.llama_factory.config import SFTConfig, RLConfig, DPOConfig
+
+logger = logging.getLogger(__name__)
+
+
+class TrainingFactory:
+    """
+    Unified factory for all AI model training operations.
+
+    This class provides simplified interfaces for:
+    - LLM training using LlamaFactory
+    - Image model training using Flux/LoRA
+    - Model evaluation and benchmarking
+
+    Example usage for fine-tuning Gemma 3:4B:
+    ```python
+    from isa_model.training import TrainingFactory
+
+    factory = TrainingFactory()
+
+    # Fine-tune with your dataset
+    model_path = factory.finetune_llm(
+        model_name="google/gemma-2-4b-it",
+        dataset_path="path/to/your/data.json",
+        training_type="sft",
+        use_lora=True,
+        num_epochs=3,
+        batch_size=4,
+        learning_rate=2e-5
+    )
+
+    # Train with DPO for preference optimization
+    dpo_model = factory.train_with_preferences(
+        model_path=model_path,
+        preference_data="path/to/preferences.json",
+        beta=0.1
+    )
+    ```
+    """
+
+    def __init__(self, base_output_dir: Optional[str] = None):
+        """
+        Initialize the training factory.
+
+        Args:
+            base_output_dir: Base directory for all training outputs
+        """
+        self.base_output_dir = base_output_dir or os.path.join(os.getcwd(), "training_outputs")
+        os.makedirs(self.base_output_dir, exist_ok=True)
+
+        # Initialize sub-factories
+        self.llm_factory = LlamaFactory(base_output_dir=os.path.join(self.base_output_dir, "llm"))
+
+        logger.info(f"TrainingFactory initialized with output dir: {self.base_output_dir}")
+
+    def _get_output_dir(self, model_name: str, training_type: str) -> str:
+        """Generate timestamped output directory."""
+        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+        safe_model_name = model_name.replace("/", "_").replace(":", "_")
+        return os.path.join(self.base_output_dir, f"{safe_model_name}_{training_type}_{timestamp}")
+
+    # =================
+    # LLM Training Methods
+    # =================
+
+    def finetune_llm(
+        self,
+        model_name: str,
+        dataset_path: str,
+        training_type: str = "sft",
+        output_dir: Optional[str] = None,
+        dataset_format: str = "alpaca",
+        use_lora: bool = True,
+        batch_size: int = 4,
+        num_epochs: int = 3,
+        learning_rate: float = 2e-5,
+        max_length: int = 1024,
+        lora_rank: int = 8,
+        lora_alpha: int = 16,
+        val_dataset_path: Optional[str] = None,
+        **kwargs
+    ) -> str:
+        """
+        Fine-tune an LLM model.
+
+        Args:
+            model_name: Model identifier (e.g., "google/gemma-2-4b-it", "meta-llama/Llama-2-7b-hf")
+            dataset_path: Path to training dataset
+            training_type: Type of training ("sft", "dpo", "rlhf")
+            output_dir: Custom output directory
+            dataset_format: Dataset format ("alpaca", "sharegpt", "custom")
+            use_lora: Whether to use LoRA for efficient training
+            batch_size: Training batch size
+            num_epochs: Number of training epochs
+            learning_rate: Learning rate
+            max_length: Maximum sequence length
+            lora_rank: LoRA rank parameter
+            lora_alpha: LoRA alpha parameter
+            val_dataset_path: Path to validation dataset (optional)
+            **kwargs: Additional training parameters
+
+        Returns:
+            Path to the trained model
+
+        Example:
+            ```python
+            # Fine-tune Gemma 3:4B with your dataset
+            model_path = factory.finetune_llm(
+                model_name="google/gemma-2-4b-it",
+                dataset_path="my_training_data.json",
+                training_type="sft",
+                use_lora=True,
+                num_epochs=3,
+                batch_size=4
+            )
+            ```
+        """
+        if not output_dir:
+            output_dir = self._get_output_dir(model_name, training_type)
+
+        # Convert format string to enum
+        format_map = {
+            "alpaca": DatasetFormat.ALPACA,
+            "sharegpt": DatasetFormat.SHAREGPT,
+            "custom": DatasetFormat.CUSTOM
+        }
+        dataset_format_enum = format_map.get(dataset_format, DatasetFormat.ALPACA)
+
+        if training_type.lower() == "sft":
+            return self.llm_factory.finetune(
+                model_path=model_name,
+                train_data=dataset_path,
+                val_data=val_dataset_path,
+                output_dir=output_dir,
+                dataset_format=dataset_format_enum,
+                use_lora=use_lora,
+                batch_size=batch_size,
+                num_epochs=num_epochs,
+                learning_rate=learning_rate,
+                max_length=max_length,
+                lora_rank=lora_rank,
+                lora_alpha=lora_alpha,
+                **kwargs
+            )
+        else:
+            raise ValueError(f"Training type '{training_type}' not supported yet. Use 'sft' for now.")
+
+    def train_with_preferences(
+        self,
+        model_path: str,
+        preference_data: str,
+        output_dir: Optional[str] = None,
+        reference_model: Optional[str] = None,
+        beta: float = 0.1,
+        use_lora: bool = True,
+        batch_size: int = 4,
+        num_epochs: int = 3,
+        learning_rate: float = 5e-6,
+        val_data: Optional[str] = None,
+        **kwargs
+    ) -> str:
+        """
+        Train model with preference data using DPO.
+
+        Args:
+            model_path: Path to the base model
+            preference_data: Path to preference dataset
+            output_dir: Custom output directory
+            reference_model: Reference model for DPO (optional)
+            beta: DPO beta parameter
+            use_lora: Whether to use LoRA
+            batch_size: Training batch size
+            num_epochs: Number of epochs
+            learning_rate: Learning rate
+            val_data: Validation data path
+            **kwargs: Additional parameters
+
+        Returns:
+            Path to the trained model
+        """
+        if not output_dir:
+            model_name = os.path.basename(model_path)
+            output_dir = self._get_output_dir(model_name, "dpo")
+
+        return self.llm_factory.dpo(
+            model_path=model_path,
+            train_data=preference_data,
+            val_data=val_data,
+            reference_model=reference_model,
+            output_dir=output_dir,
+            use_lora=use_lora,
+            batch_size=batch_size,
+            num_epochs=num_epochs,
+            learning_rate=learning_rate,
+            beta=beta,
+            **kwargs
+        )
+
+    def train_reward_model(
+        self,
+        model_path: str,
+        reward_data: str,
+        output_dir: Optional[str] = None,
+        use_lora: bool = True,
+        batch_size: int = 8,
+        num_epochs: int = 3,
+        learning_rate: float = 1e-5,
+        val_data: Optional[str] = None,
+        **kwargs
+    ) -> str:
+        """
+        Train a reward model for RLHF.
+
+        Args:
+            model_path: Base model path
+            reward_data: Reward training data
+            output_dir: Output directory
+            use_lora: Whether to use LoRA
+            batch_size: Batch size
+            num_epochs: Number of epochs
+            learning_rate: Learning rate
+            val_data: Validation data
+            **kwargs: Additional parameters
+
+        Returns:
+            Path to trained reward model
+        """
+        if not output_dir:
+            model_name = os.path.basename(model_path)
+            output_dir = self._get_output_dir(model_name, "reward")
+
+        return self.llm_factory.train_reward_model(
+            model_path=model_path,
+            train_data=reward_data,
+            val_data=val_data,
+            output_dir=output_dir,
+            use_lora=use_lora,
+            batch_size=batch_size,
+            num_epochs=num_epochs,
+            learning_rate=learning_rate,
+            **kwargs
+        )
+
+    # =================
+    # Image Model Training Methods
+    # =================
+
+    def train_image_model(
+        self,
+        model_type: str = "flux",
+        training_images_dir: str = "",
+        output_dir: Optional[str] = None,
+        use_lora: bool = True,
+        num_epochs: int = 1000,
+        batch_size: int = 1,
+        learning_rate: float = 1e-4,
+        **kwargs
+    ) -> str:
+        """
+        Train an image generation model.
+
+        Args:
+            model_type: Type of model ("flux", "lora")
+            training_images_dir: Directory containing training images
+            output_dir: Output directory
+            use_lora: Whether to use LoRA
+            num_epochs: Training epochs
+            batch_size: Batch size
+            learning_rate: Learning rate
+            **kwargs: Additional parameters
+
+        Returns:
+            Path to trained model
+        """
+        if not output_dir:
+            output_dir = self._get_output_dir("image_model", model_type)
+
+        # TODO: Implement image model training
+        logger.warning("Image model training not fully implemented yet")
+        return output_dir
+
+    # =================
+    # Utility Methods
+    # =================
+
+    def get_training_status(self, output_dir: str) -> Dict[str, Any]:
+        """
+        Get training status from output directory.
+
+        Args:
+            output_dir: Training output directory
+
+        Returns:
+            Dictionary with training status information
+        """
+        status = {
+            "output_dir": output_dir,
+            "exists": os.path.exists(output_dir),
+            "files": []
+        }
+
+        if status["exists"]:
+            status["files"] = os.listdir(output_dir)
+
+        return status
+
+    def list_trained_models(self) -> List[Dict[str, Any]]:
+        """
+        List all trained models in the output directory.
+
+        Returns:
+            List of model information dictionaries
+        """
+        models = []
+
+        if os.path.exists(self.base_output_dir):
+            for item in os.listdir(self.base_output_dir):
+                item_path = os.path.join(self.base_output_dir, item)
+                if os.path.isdir(item_path):
+                    models.append({
+                        "name": item,
+                        "path": item_path,
+                        "created": datetime.datetime.fromtimestamp(
+                            os.path.getctime(item_path)
+                        ).isoformat()
+                    })
+
+        return sorted(models, key=lambda x: x["created"], reverse=True)
+
+
+# Convenience functions for quick access
+def finetune_gemma(
+    dataset_path: str,
+    model_size: str = "4b",
+    output_dir: Optional[str] = None,
+    **kwargs
+) -> str:
+    """
+    Quick function to fine-tune Gemma models.
+
+    Args:
+        dataset_path: Path to training dataset
+        model_size: Model size ("2b", "4b", "7b")
+        output_dir: Output directory
+        **kwargs: Additional training parameters
+
+    Returns:
+        Path to fine-tuned model
+
+    Example:
+        ```python
+        from isa_model.training import finetune_gemma
+
+        model_path = finetune_gemma(
+            dataset_path="my_data.json",
+            model_size="4b",
+            num_epochs=3,
+            batch_size=4
+        )
+        ```
+    """
+    factory = TrainingFactory()
+
+    model_map = {
+        "2b": "google/gemma-2-2b-it",
+        "4b": "google/gemma-2-4b-it",
+        "7b": "google/gemma-2-7b-it"
+    }
+
+    model_name = model_map.get(model_size, "google/gemma-2-4b-it")
+
+    return factory.finetune_llm(
+        model_name=model_name,
+        dataset_path=dataset_path,
+        output_dir=output_dir,
+        **kwargs
+    )
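
`finetune_llm` defaults to `dataset_format="alpaca"`, but the diff does not show what such a dataset file looks like. As a point of reference, a record in the widely used Alpaca layout carries `instruction`/`input`/`output` fields; this is a minimal sketch of preparing one, and the exact schema the LlamaFactory engine expects should be verified before training.

```python
# Hypothetical example dataset in the common Alpaca layout; field names follow the
# widespread convention and are not confirmed by this diff.
import json

records = [
    {
        "instruction": "Summarize the following support ticket in one sentence.",
        "input": "Customer reports the mobile app crashes when uploading photos larger than 10 MB.",
        "output": "The mobile app crashes on photo uploads over 10 MB.",
    }
]

# Write the file referenced as dataset_path in the examples above.
with open("my_training_data.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=2)
```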