isa-model 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_registry.py +273 -46
  3. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
  4. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
  5. isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
  6. isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
  7. isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
  8. isa_model/eval/__init__.py +56 -0
  9. isa_model/eval/benchmarks.py +469 -0
  10. isa_model/eval/factory.py +582 -0
  11. isa_model/eval/metrics.py +628 -0
  12. isa_model/inference/ai_factory.py +98 -93
  13. isa_model/inference/providers/openai_provider.py +21 -7
  14. isa_model/inference/providers/replicate_provider.py +18 -5
  15. isa_model/inference/providers/triton_provider.py +1 -1
  16. isa_model/inference/services/audio/base_stt_service.py +91 -0
  17. isa_model/inference/services/audio/base_tts_service.py +136 -0
  18. isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
  19. isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
  20. isa_model/inference/services/llm/__init__.py +0 -4
  21. isa_model/inference/services/llm/base_llm_service.py +134 -0
  22. isa_model/inference/services/llm/ollama_llm_service.py +1 -10
  23. isa_model/inference/services/llm/openai_llm_service.py +70 -61
  24. isa_model/inference/services/vision/__init__.py +1 -1
  25. isa_model/inference/services/vision/ollama_vision_service.py +4 -4
  26. isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
  27. isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
  28. isa_model/training/__init__.py +44 -0
  29. isa_model/training/factory.py +393 -0
  30. isa_model-0.1.1.dist-info/METADATA +327 -0
  31. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/RECORD +35 -60
  32. isa_model/deployment/mlflow_gateway/__init__.py +0 -8
  33. isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
  34. isa_model/deployment/unified_multimodal_client.py +0 -341
  35. isa_model/inference/adapter/triton_adapter.py +0 -453
  36. isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
  37. isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
  38. isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
  39. isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
  40. isa_model/inference/backends/__init__.py +0 -53
  41. isa_model/inference/backends/base_backend_client.py +0 -26
  42. isa_model/inference/backends/container_services.py +0 -104
  43. isa_model/inference/backends/local_services.py +0 -72
  44. isa_model/inference/backends/openai_client.py +0 -130
  45. isa_model/inference/backends/replicate_client.py +0 -197
  46. isa_model/inference/backends/third_party_services.py +0 -239
  47. isa_model/inference/backends/triton_client.py +0 -97
  48. isa_model/inference/client_sdk/client.py +0 -134
  49. isa_model/inference/client_sdk/client_data_std.py +0 -34
  50. isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
  51. isa_model/inference/client_sdk/exceptions.py +0 -0
  52. isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
  53. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
  54. isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
  55. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
  56. isa_model/inference/providers/vllm_provider.py +0 -0
  57. isa_model/inference/providers/yyds_provider.py +0 -83
  58. isa_model/inference/services/audio/fish_speech/handler.py +0 -215
  59. isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
  60. isa_model/inference/services/audio/triton_speech_service.py +0 -138
  61. isa_model/inference/services/audio/whisper_service.py +0 -186
  62. isa_model/inference/services/base_tts_service.py +0 -66
  63. isa_model/inference/services/embedding/bge_service.py +0 -183
  64. isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
  65. isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
  66. isa_model/inference/services/llm/gemma_service.py +0 -143
  67. isa_model/inference/services/llm/llama_service.py +0 -143
  68. isa_model/inference/services/llm/replicate_llm_service.py +0 -179
  69. isa_model/inference/services/llm/triton_llm_service.py +0 -230
  70. isa_model/inference/services/vision/replicate_vision_service.py +0 -241
  71. isa_model/inference/services/vision/triton_vision_service.py +0 -199
  72. isa_model-0.1.0.dist-info/METADATA +0 -116
  73. /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
  74. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/WHEEL +0 -0
  75. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/licenses/LICENSE +0 -0
  76. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/top_level.txt +0 -0
isa_model/inference/services/vision/replicate_image_gen_service.py
@@ -0,0 +1,185 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+ """
+ Replicate Vision service
+ Interacts with the Replicate API; supports image generation and image analysis
+ """
+
+ import os
+ import time
+ import uuid
+ import logging
+ from typing import Dict, Any, List, Optional, Union, Tuple
+ import asyncio
+ import aiohttp
+ import replicate  # the replicate client library
+ from PIL import Image
+ from io import BytesIO
+
+ # Adjust the BaseService import path to match your project structure
+ from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.providers.base_provider import BaseProvider
+ from isa_model.inference.base import ModelType
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class ReplicateVisionService(BaseService):
+     """
+     Replicate Vision service for image generation and analysis.
+     Adjusted to use native async calls and streamlined file handling.
+     """
+
+     def __init__(self, provider: BaseProvider, model_name: str):
+         """
+         Initialize the Replicate Vision service
+         """
+         super().__init__(provider, model_name)
+         # Read the API token from the provider config or the environment
+         self.api_token = self.provider.config.get("api_token", os.environ.get("REPLICATE_API_TOKEN"))
+         self.model_type = ModelType.VISION
+
+         # Optional default settings
+         self.guidance_scale = self.provider.config.get("guidance_scale", 7.5)
+         self.num_inference_steps = self.provider.config.get("num_inference_steps", 30)
+
+         # Directory where generated images are stored
+         self.output_dir = "generated_images"
+         os.makedirs(self.output_dir, exist_ok=True)
+
+         # ★ Adjustment: set the API token for the replicate library
+         if self.api_token:
+             # The replicate library reads the token from the environment; make sure it is set
+             os.environ["REPLICATE_API_TOKEN"] = self.api_token
+         else:
+             logger.warning("Replicate API token not found. The service may not work correctly.")
+
+     async def _prepare_input_files(self, input_data: Dict[str, Any]) -> Tuple[Dict[str, Any], List[Any]]:
+         """
+         ★ New helper: prepare the input payload by turning local file paths into file objects.
+         This lets the service handle local files and URLs uniformly.
+         """
+         prepared_input = input_data.copy()
+         files_to_close = []
+         for key, value in prepared_input.items():
+             # If the value is a string that looks like an existing local file path
+             if isinstance(value, str) and not value.startswith(('http://', 'https://')) and os.path.exists(value):
+                 logger.info(f"Detected local file path '{value}'; opening the file.")
+                 try:
+                     file_handle = open(value, "rb")
+                     prepared_input[key] = file_handle
+                     files_to_close.append(file_handle)
+                 except Exception as e:
+                     logger.error(f"Failed to open file '{value}': {e}")
+                     raise
+         return prepared_input, files_to_close
+
+     async def generate_image(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Generate images with a Replicate model (now natively async)
+         """
+         prepared_input, files_to_close = await self._prepare_input_files(input_data)
+         try:
+             # Apply default parameters
+             if "guidance_scale" not in prepared_input:
+                 prepared_input["guidance_scale"] = self.guidance_scale
+             if "num_inference_steps" not in prepared_input:
+                 prepared_input["num_inference_steps"] = self.num_inference_steps
+
+             logger.info(f"Generating images with model {self.model_name} (native async)")
+
+             # ★ Adjustment: use the natively async replicate.async_run
+             output = await replicate.async_run(self.model_name, input=prepared_input)
+
+             # Normalize the result into a standard format (this logic is unchanged)
+             if isinstance(output, list):
+                 urls = output
+             else:
+                 urls = [output]
+
+             result = {
+                 "urls": urls,
+                 "metadata": {
+                     "model": self.model_name,
+                     "input": input_data  # return the original input for reference
+                 }
+             }
+             logger.info(f"Image generation finished: {result['urls']}")
+             return result
+         except Exception as e:
+             logger.error(f"Image generation failed: {e}")
+             raise
+         finally:
+             # ★ New: make sure every opened file is closed
+             for f in files_to_close:
+                 f.close()
+
+     async def analyze_image(self, image_path: str, prompt: str) -> Dict[str, Any]:
+         """
+         Analyze an image (now natively async)
+         """
+         input_data = {"image": image_path, "prompt": prompt}
+         prepared_input, files_to_close = await self._prepare_input_files(input_data)
+         try:
+             logger.info(f"Analyzing image with model {self.model_name} (native async)")
+             # ★ Adjustment: use the natively async replicate.async_run
+             output = await replicate.async_run(self.model_name, input=prepared_input)
+
+             result = {
+                 "text": "".join(output) if isinstance(output, list) else output,
+                 "metadata": {
+                     "model": self.model_name,
+                     "input": input_data
+                 }
+             }
+             logger.info("Image analysis finished")
+             return result
+         except Exception as e:
+             logger.error(f"Image analysis failed: {e}")
+             raise
+         finally:
+             # ★ New: make sure every opened file is closed
+             for f in files_to_close:
+                 f.close()
+
+     async def generate_and_save(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+         """Generate images and save them locally (unchanged)"""
+         result = await self.generate_image(input_data)
+         saved_paths = []
+         for i, url in enumerate(result["urls"]):
+             timestamp = int(time.time())
+             file_name = f"{self.output_dir}/{timestamp}_{uuid.uuid4().hex[:8]}_{i+1}.png"
+             try:
+                 # Convert FileOutput objects to their URL string if necessary
+                 url_str = str(url)
+                 await self._download_image(url_str, file_name)
+                 saved_paths.append(file_name)
+                 logger.info(f"Image saved to: {file_name}")
+             except Exception as e:
+                 logger.error(f"Failed to save image: {e}")
+         result["saved_paths"] = saved_paths
+         return result
+
+     async def _download_image(self, url: str, save_path: str) -> None:
+         """Asynchronously download an image and save it (unchanged)"""
+         try:
+             async with aiohttp.ClientSession() as session:
+                 async with session.get(url) as response:
+                     response.raise_for_status()
+                     content = await response.read()
+                     with Image.open(BytesIO(content)) as img:
+                         img.save(save_path)
+         except Exception as e:
+             logger.error(f"Error downloading image {url}: {e}")
+             raise
+
+     # `load` and `unload` are typically lightweight when targeting the Replicate API
+     async def load(self) -> None:
+         if not self.api_token:
+             raise ValueError("Missing Replicate API token; set the REPLICATE_API_TOKEN environment variable or provide it in the provider config")
+         logger.info(f"Replicate Vision service ready, using model: {self.model_name}")
+
+     async def unload(self) -> None:
+         logger.info(f"Unloading Replicate Vision service: {self.model_name}")
isa_model/training/__init__.py
@@ -0,0 +1,44 @@
+ """
+ ISA Model Training Framework
+
+ This module provides unified interfaces for training various types of AI models:
+ - LLM training with LlamaFactory
+ - Image model training with Flux/LoRA
+ - Model evaluation and benchmarking
+
+ Usage:
+     from isa_model.training import TrainingFactory
+
+     # Create training factory
+     factory = TrainingFactory()
+
+     # Fine-tune Gemma 3:4B
+     model_path = factory.finetune_llm(
+         model_name="gemma:4b",
+         dataset_path="path/to/data.json",
+         training_type="sft"
+     )
+ """
+
+ from .factory import TrainingFactory, finetune_gemma
+ from .engine.llama_factory import (
+     LlamaFactory,
+     LlamaFactoryConfig,
+     SFTConfig,
+     RLConfig,
+     DPOConfig,
+     TrainingStrategy,
+     DatasetFormat
+ )
+
+ __all__ = [
+     "TrainingFactory",
+     "finetune_gemma",
+     "LlamaFactory",
+     "LlamaFactoryConfig",
+     "SFTConfig",
+     "RLConfig",
+     "DPOConfig",
+     "TrainingStrategy",
+     "DatasetFormat"
+ ]
isa_model/training/factory.py
@@ -0,0 +1,393 @@
+ """
+ Unified Training Factory for ISA Model Framework
+
+ This factory provides a single interface for all training operations:
+ - LLM fine-tuning (SFT, DPO, RLHF)
+ - Image model training (Flux, LoRA)
+ - Model evaluation and benchmarking
+ """
+
+ import os
+ import logging
+ from typing import Optional, Dict, Any, Union, List
+ from pathlib import Path
+ import datetime
+
+ from .engine.llama_factory import LlamaFactory, TrainingStrategy, DatasetFormat
+ from .engine.llama_factory.config import SFTConfig, RLConfig, DPOConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ class TrainingFactory:
+     """
+     Unified factory for all AI model training operations.
+
+     This class provides simplified interfaces for:
+     - LLM training using LlamaFactory
+     - Image model training using Flux/LoRA
+     - Model evaluation and benchmarking
+
+     Example usage for fine-tuning Gemma 3:4B:
+     ```python
+     from isa_model.training import TrainingFactory
+
+     factory = TrainingFactory()
+
+     # Fine-tune with your dataset
+     model_path = factory.finetune_llm(
+         model_name="google/gemma-2-4b-it",
+         dataset_path="path/to/your/data.json",
+         training_type="sft",
+         use_lora=True,
+         num_epochs=3,
+         batch_size=4,
+         learning_rate=2e-5
+     )
+
+     # Train with DPO for preference optimization
+     dpo_model = factory.train_with_preferences(
+         model_path=model_path,
+         preference_data="path/to/preferences.json",
+         beta=0.1
+     )
+     ```
+     """
+
+     def __init__(self, base_output_dir: Optional[str] = None):
+         """
+         Initialize the training factory.
+
+         Args:
+             base_output_dir: Base directory for all training outputs
+         """
+         self.base_output_dir = base_output_dir or os.path.join(os.getcwd(), "training_outputs")
+         os.makedirs(self.base_output_dir, exist_ok=True)
+
+         # Initialize sub-factories
+         self.llm_factory = LlamaFactory(base_output_dir=os.path.join(self.base_output_dir, "llm"))
+
+         logger.info(f"TrainingFactory initialized with output dir: {self.base_output_dir}")
+
+     def _get_output_dir(self, model_name: str, training_type: str) -> str:
+         """Generate a timestamped output directory."""
+         timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+         safe_model_name = model_name.replace("/", "_").replace(":", "_")
+         return os.path.join(self.base_output_dir, f"{safe_model_name}_{training_type}_{timestamp}")
+
+     # =================
+     # LLM Training Methods
+     # =================
+
+     def finetune_llm(
+         self,
+         model_name: str,
+         dataset_path: str,
+         training_type: str = "sft",
+         output_dir: Optional[str] = None,
+         dataset_format: str = "alpaca",
+         use_lora: bool = True,
+         batch_size: int = 4,
+         num_epochs: int = 3,
+         learning_rate: float = 2e-5,
+         max_length: int = 1024,
+         lora_rank: int = 8,
+         lora_alpha: int = 16,
+         val_dataset_path: Optional[str] = None,
+         **kwargs
+     ) -> str:
+         """
+         Fine-tune an LLM model.
+
+         Args:
+             model_name: Model identifier (e.g., "google/gemma-2-4b-it", "meta-llama/Llama-2-7b-hf")
+             dataset_path: Path to training dataset
+             training_type: Type of training ("sft", "dpo", "rlhf")
+             output_dir: Custom output directory
+             dataset_format: Dataset format ("alpaca", "sharegpt", "custom")
+             use_lora: Whether to use LoRA for efficient training
+             batch_size: Training batch size
+             num_epochs: Number of training epochs
+             learning_rate: Learning rate
+             max_length: Maximum sequence length
+             lora_rank: LoRA rank parameter
+             lora_alpha: LoRA alpha parameter
+             val_dataset_path: Path to validation dataset (optional)
+             **kwargs: Additional training parameters
+
+         Returns:
+             Path to the trained model
+
+         Example:
+             ```python
+             # Fine-tune Gemma 3:4B with your dataset
+             model_path = factory.finetune_llm(
+                 model_name="google/gemma-2-4b-it",
+                 dataset_path="my_training_data.json",
+                 training_type="sft",
+                 use_lora=True,
+                 num_epochs=3,
+                 batch_size=4
+             )
+             ```
+         """
+         if not output_dir:
+             output_dir = self._get_output_dir(model_name, training_type)
+
+         # Convert the format string to its enum
+         format_map = {
+             "alpaca": DatasetFormat.ALPACA,
+             "sharegpt": DatasetFormat.SHAREGPT,
+             "custom": DatasetFormat.CUSTOM
+         }
+         dataset_format_enum = format_map.get(dataset_format, DatasetFormat.ALPACA)
+
+         if training_type.lower() == "sft":
+             return self.llm_factory.finetune(
+                 model_path=model_name,
+                 train_data=dataset_path,
+                 val_data=val_dataset_path,
+                 output_dir=output_dir,
+                 dataset_format=dataset_format_enum,
+                 use_lora=use_lora,
+                 batch_size=batch_size,
+                 num_epochs=num_epochs,
+                 learning_rate=learning_rate,
+                 max_length=max_length,
+                 lora_rank=lora_rank,
+                 lora_alpha=lora_alpha,
+                 **kwargs
+             )
+         else:
+             raise ValueError(f"Training type '{training_type}' not supported yet. Use 'sft' for now.")
+
+     def train_with_preferences(
+         self,
+         model_path: str,
+         preference_data: str,
+         output_dir: Optional[str] = None,
+         reference_model: Optional[str] = None,
+         beta: float = 0.1,
+         use_lora: bool = True,
+         batch_size: int = 4,
+         num_epochs: int = 3,
+         learning_rate: float = 5e-6,
+         val_data: Optional[str] = None,
+         **kwargs
+     ) -> str:
+         """
+         Train a model on preference data using DPO.
+
+         Args:
+             model_path: Path to the base model
+             preference_data: Path to preference dataset
+             output_dir: Custom output directory
+             reference_model: Reference model for DPO (optional)
+             beta: DPO beta parameter
+             use_lora: Whether to use LoRA
+             batch_size: Training batch size
+             num_epochs: Number of epochs
+             learning_rate: Learning rate
+             val_data: Validation data path
+             **kwargs: Additional parameters
+
+         Returns:
+             Path to the trained model
+         """
+         if not output_dir:
+             model_name = os.path.basename(model_path)
+             output_dir = self._get_output_dir(model_name, "dpo")
+
+         return self.llm_factory.dpo(
+             model_path=model_path,
+             train_data=preference_data,
+             val_data=val_data,
+             reference_model=reference_model,
+             output_dir=output_dir,
+             use_lora=use_lora,
+             batch_size=batch_size,
+             num_epochs=num_epochs,
+             learning_rate=learning_rate,
+             beta=beta,
+             **kwargs
+         )
+
+     def train_reward_model(
+         self,
+         model_path: str,
+         reward_data: str,
+         output_dir: Optional[str] = None,
+         use_lora: bool = True,
+         batch_size: int = 8,
+         num_epochs: int = 3,
+         learning_rate: float = 1e-5,
+         val_data: Optional[str] = None,
+         **kwargs
+     ) -> str:
+         """
+         Train a reward model for RLHF.
+
+         Args:
+             model_path: Base model path
+             reward_data: Reward training data
+             output_dir: Output directory
+             use_lora: Whether to use LoRA
+             batch_size: Batch size
+             num_epochs: Number of epochs
+             learning_rate: Learning rate
+             val_data: Validation data
+             **kwargs: Additional parameters
+
+         Returns:
+             Path to trained reward model
+         """
+         if not output_dir:
+             model_name = os.path.basename(model_path)
+             output_dir = self._get_output_dir(model_name, "reward")
+
+         return self.llm_factory.train_reward_model(
+             model_path=model_path,
+             train_data=reward_data,
+             val_data=val_data,
+             output_dir=output_dir,
+             use_lora=use_lora,
+             batch_size=batch_size,
+             num_epochs=num_epochs,
+             learning_rate=learning_rate,
+             **kwargs
+         )
+
+     # =================
+     # Image Model Training Methods
+     # =================
+
+     def train_image_model(
+         self,
+         model_type: str = "flux",
+         training_images_dir: str = "",
+         output_dir: Optional[str] = None,
+         use_lora: bool = True,
+         num_epochs: int = 1000,
+         batch_size: int = 1,
+         learning_rate: float = 1e-4,
+         **kwargs
+     ) -> str:
+         """
+         Train an image generation model.
+
+         Args:
+             model_type: Type of model ("flux", "lora")
+             training_images_dir: Directory containing training images
+             output_dir: Output directory
+             use_lora: Whether to use LoRA
+             num_epochs: Training epochs
+             batch_size: Batch size
+             learning_rate: Learning rate
+             **kwargs: Additional parameters
+
+         Returns:
+             Path to trained model
+         """
+         if not output_dir:
+             output_dir = self._get_output_dir("image_model", model_type)
+
+         # TODO: Implement image model training
+         logger.warning("Image model training not fully implemented yet")
+         return output_dir
+
+     # =================
+     # Utility Methods
+     # =================
+
+     def get_training_status(self, output_dir: str) -> Dict[str, Any]:
+         """
+         Get training status from an output directory.
+
+         Args:
+             output_dir: Training output directory
+
+         Returns:
+             Dictionary with training status information
+         """
+         status = {
+             "output_dir": output_dir,
+             "exists": os.path.exists(output_dir),
+             "files": []
+         }
+
+         if status["exists"]:
+             status["files"] = os.listdir(output_dir)
+
+         return status
+
+     def list_trained_models(self) -> List[Dict[str, Any]]:
+         """
+         List all trained models in the output directory.
+
+         Returns:
+             List of model information dictionaries
+         """
+         models = []
+
+         if os.path.exists(self.base_output_dir):
+             for item in os.listdir(self.base_output_dir):
+                 item_path = os.path.join(self.base_output_dir, item)
+                 if os.path.isdir(item_path):
+                     models.append({
+                         "name": item,
+                         "path": item_path,
+                         "created": datetime.datetime.fromtimestamp(
+                             os.path.getctime(item_path)
+                         ).isoformat()
+                     })
+
+         return sorted(models, key=lambda x: x["created"], reverse=True)
+
+
+ # Convenience functions for quick access
+ def finetune_gemma(
+     dataset_path: str,
+     model_size: str = "4b",
+     output_dir: Optional[str] = None,
+     **kwargs
+ ) -> str:
+     """
+     Quick function to fine-tune Gemma models.
+
+     Args:
+         dataset_path: Path to training dataset
+         model_size: Model size ("2b", "4b", "7b")
+         output_dir: Output directory
+         **kwargs: Additional training parameters
+
+     Returns:
+         Path to fine-tuned model
+
+     Example:
+         ```python
+         from isa_model.training import finetune_gemma
+
+         model_path = finetune_gemma(
+             dataset_path="my_data.json",
+             model_size="4b",
+             num_epochs=3,
+             batch_size=4
+         )
+         ```
+     """
+     factory = TrainingFactory()
+
+     model_map = {
+         "2b": "google/gemma-2-2b-it",
+         "4b": "google/gemma-2-4b-it",
+         "7b": "google/gemma-2-7b-it"
+     }
+
+     model_name = model_map.get(model_size, "google/gemma-2-4b-it")
+
+     return factory.finetune_llm(
+         model_name=model_name,
+         dataset_path=dataset_path,
+         output_dir=output_dir,
+         **kwargs
+     )
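The new utility methods combine into a quick audit of past runs; a short sketch (the run directory name passed to `get_training_status` is illustrative, following the `{model}_{type}_{timestamp}` pattern `_get_output_dir` produces):

```python
# Inspecting training outputs with the new utility methods; the run
# directory name below is illustrative, not a real artifact of this diff.
from isa_model.training import TrainingFactory

factory = TrainingFactory(base_output_dir="training_outputs")

# list_trained_models() returns newest-first dicts with name, path, created
for model in factory.list_trained_models():
    print(model["name"], model["created"])

# get_training_status() reports existence and directory contents
status = factory.get_training_status("training_outputs/google_gemma-2-4b-it_sft_20240101_120000")
print(status["exists"], status["files"])
```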