isa-model 0.1.1__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +142 -240
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/services/audio/openai_tts_service.py +104 -3
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +111 -1
- isa_model/inference/services/llm/ollama_llm_service.py +234 -26
- isa_model/inference/services/llm/openai_llm_service.py +225 -28
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/ollama_vision_service.py +143 -17
- isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.2.8.dist-info/METADATA +465 -0
- isa_model-0.2.8.dist-info/RECORD +86 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.1.1.dist-info/METADATA +0 -327
- isa_model-0.1.1.dist-info/RECORD +0 -92
- isa_model-0.1.1.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.1.1.dist-info → isa_model-0.2.8.dist-info}/WHEEL +0 -0
- {isa_model-0.1.1.dist-info → isa_model-0.2.8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,720 @@
"""
Enhanced Multi-Modal Training Framework for ISA Model SDK

Supports training for:
- LLM models (GPT, Gemma, Llama, etc.) with Unsloth acceleration
- Stable Diffusion models
- Traditional ML models (scikit-learn, XGBoost, etc.)
- Computer Vision models (CNN, Vision Transformers)
- Audio models (Whisper, etc.)
"""

import os
import json
import logging
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List, Union, Tuple
from pathlib import Path
import datetime

try:
    import torch
    import torch.nn as nn
    from transformers import (
        AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification,
        Trainer, TrainingArguments, DataCollatorForLanguageModeling
    )
    from peft import LoraConfig, get_peft_model, TaskType
    from datasets import Dataset
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False

try:
    from unsloth import FastLanguageModel
    from unsloth.trainer import UnslothTrainer
    UNSLOTH_AVAILABLE = True
except ImportError:
    UNSLOTH_AVAILABLE = False

try:
    from diffusers import StableDiffusionPipeline, UNet2DConditionModel
    from diffusers.training_utils import EMAModel
    DIFFUSERS_AVAILABLE = True
except ImportError:
    DIFFUSERS_AVAILABLE = False

try:
    import sklearn
    from sklearn.base import BaseEstimator
    import xgboost as xgb
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False

from .config import TrainingConfig, LoRAConfig, DatasetConfig

logger = logging.getLogger(__name__)

# Unsloth supported models
UNSLOTH_SUPPORTED_MODELS = [
    "google/gemma-2-2b",
    "google/gemma-2-2b-it",
    "google/gemma-2-4b",
    "google/gemma-2-4b-it",
    "google/gemma-2-7b",
    "google/gemma-2-7b-it",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "microsoft/DialoGPT-medium",
    "microsoft/DialoGPT-large",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
]


class BaseTrainer(ABC):
    """
    Abstract base class for all trainers in the ISA Model SDK.

    This class defines the common interface that all trainers must implement,
    regardless of the model type (LLM, Stable Diffusion, ML, etc.).
    """

    def __init__(self, config: TrainingConfig):
        """
        Initialize the base trainer.

        Args:
            config: Training configuration object
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.dataset = None
        self.training_args = None

        # Create output directory
        os.makedirs(config.output_dir, exist_ok=True)

        # Setup comprehensive logging
        self._setup_logging()

        logger.info(f"Initialized {self.__class__.__name__} with config: {config.model_name}")
        logger.info(f"Training configuration: {config.to_dict()}")

    def _setup_logging(self):
        """Setup comprehensive logging for training process"""
        log_dir = Path(self.config.output_dir) / "logs"
        log_dir.mkdir(exist_ok=True)

        # Create formatters
        detailed_formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s'
        )

        # File handler for detailed logs
        file_handler = logging.FileHandler(log_dir / 'training_detailed.log')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(detailed_formatter)

        # File handler for errors only
        error_handler = logging.FileHandler(log_dir / 'training_errors.log')
        error_handler.setLevel(logging.ERROR)
        error_handler.setFormatter(detailed_formatter)

        # Console handler for important info
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))

        # Configure logger
        logger.addHandler(file_handler)
        logger.addHandler(error_handler)
        logger.addHandler(console_handler)
        logger.setLevel(logging.DEBUG)

    @abstractmethod
    def load_model(self) -> None:
        """Load the model and tokenizer."""
        pass

    @abstractmethod
    def prepare_dataset(self) -> None:
        """Prepare the training dataset."""
        pass

    @abstractmethod
    def setup_training(self) -> None:
        """Setup training arguments and trainer."""
        pass

    @abstractmethod
    def train(self) -> str:
        """Execute the training process."""
        pass

    @abstractmethod
    def save_model(self, output_path: str) -> None:
        """Save the trained model."""
        pass

    def validate_config(self) -> List[str]:
        """Validate the training configuration."""
        logger.debug("Validating training configuration...")
        issues = []

        if not self.config.model_name:
            issues.append("model_name is required")

        if not self.config.output_dir:
            issues.append("output_dir is required")

        if self.config.num_epochs <= 0:
            issues.append("num_epochs must be positive")

        if self.config.batch_size <= 0:
            issues.append("batch_size must be positive")

        if issues:
            logger.error(f"Configuration validation failed: {issues}")
        else:
            logger.info("Configuration validation passed")

        return issues

    def save_training_config(self) -> None:
        """Save the training configuration to output directory."""
        config_path = os.path.join(self.config.output_dir, "training_config.json")
        with open(config_path, 'w') as f:
            json.dump(self.config.to_dict(), f, indent=2)
        logger.info(f"Training config saved to: {config_path}")


class LLMTrainer(BaseTrainer):
    """
    Trainer for Large Language Models using HuggingFace Transformers with Unsloth acceleration.

    Supports:
    - Supervised Fine-Tuning (SFT)
    - LoRA (Low-Rank Adaptation)
    - Unsloth acceleration (2x faster, 50% less memory)
    - Full parameter training
    - Instruction tuning
    """

    def __init__(self, config: TrainingConfig):
        super().__init__(config)

        if not HF_AVAILABLE:
            raise ImportError("HuggingFace transformers not available. Install with: pip install transformers")

        self.trainer = None
        self.data_collator = None
        self.use_unsloth = self._should_use_unsloth()

        logger.info(f"LLM Trainer initialized - Unsloth: {'✅ Enabled' if self.use_unsloth else '❌ Disabled'}")
        if self.use_unsloth and not UNSLOTH_AVAILABLE:
            logger.warning("Unsloth requested but not available. Install with: pip install unsloth")
            self.use_unsloth = False

    def _should_use_unsloth(self) -> bool:
        """Determine if Unsloth should be used for this model"""
        if not UNSLOTH_AVAILABLE:
            return False

        # Check if model is supported by Unsloth
        model_name = self.config.model_name.lower()
        for supported_model in UNSLOTH_SUPPORTED_MODELS:
            if supported_model.lower() in model_name or model_name in supported_model.lower():
                logger.info(f"Model {self.config.model_name} is supported by Unsloth")
                return True

        logger.info(f"Model {self.config.model_name} not in Unsloth supported list, using standard training")
        return False

    def load_model(self) -> None:
        """Load the LLM model and tokenizer with optional Unsloth acceleration."""
        logger.info(f"Loading model: {self.config.model_name}")
        logger.debug(f"Using Unsloth: {self.use_unsloth}")

        try:
            if self.use_unsloth:
                self._load_model_with_unsloth()
            else:
                self._load_model_standard()

            logger.info("Model and tokenizer loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def _load_model_with_unsloth(self) -> None:
        """Load model using Unsloth for acceleration"""
        logger.info("Loading model with Unsloth acceleration...")

        # Unsloth model loading
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.config.model_name,
            max_seq_length=self.config.dataset_config.max_length if self.config.dataset_config else 1024,
            dtype=None,  # Auto-detect
            load_in_4bit=True,  # Use 4-bit quantization for memory efficiency
        )

        # Setup LoRA with Unsloth
        if self.config.lora_config and self.config.lora_config.use_lora:
            logger.info("Setting up LoRA with Unsloth...")
            lora_config = self.config.lora_config
            self.model = FastLanguageModel.get_peft_model(
                self.model,
                r=lora_config.lora_rank,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                lora_alpha=lora_config.lora_alpha,
                lora_dropout=lora_config.lora_dropout,
                bias="none",
                use_gradient_checkpointing="unsloth",  # Unsloth's optimized gradient checkpointing
                random_state=3407,
                use_rslora=False,  # Rank stabilized LoRA
                loftq_config=None,  # LoftQ
            )

        logger.info("Unsloth model loaded successfully")

    def _load_model_standard(self) -> None:
        """Load model using standard HuggingFace transformers"""
        logger.info("Loading model with standard HuggingFace transformers...")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_name,
            trust_remote_code=True,
            padding_side="right"
        )

        # Add pad token if missing
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.debug("Added pad token to tokenizer")

        # Load model
        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
            "device_map": "auto" if torch.cuda.is_available() else None
        }

        logger.debug(f"Model loading kwargs: {model_kwargs}")

        if self.config.training_type == "classification":
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.config.model_name, **model_kwargs
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config.model_name, **model_kwargs
            )

        # Setup LoRA if enabled
        if self.config.lora_config and self.config.lora_config.use_lora:
            self._setup_lora()

        logger.info("Standard model loaded successfully")

    def _setup_lora(self) -> None:
        """Setup LoRA configuration for standard training"""
        logger.info("Setting up LoRA configuration...")

        lora_config = LoraConfig(
            r=self.config.lora_config.lora_rank,
            lora_alpha=self.config.lora_config.lora_alpha,
            target_modules=self.config.lora_config.lora_target_modules,
            lora_dropout=self.config.lora_config.lora_dropout,
            bias="none",
            task_type=TaskType.CAUSAL_LM if self.config.training_type != "classification" else TaskType.SEQ_CLS
        )

        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()
        logger.info("LoRA configuration applied successfully")

    def prepare_dataset(self) -> None:
        """Prepare the training dataset."""
        logger.info("Preparing training dataset...")

        try:
            from .dataset import DatasetManager

            if not self.config.dataset_config:
                raise ValueError("Dataset configuration is required")

            dataset_manager = DatasetManager(
                self.tokenizer,
                max_length=self.config.dataset_config.max_length
            )

            train_dataset, eval_dataset = dataset_manager.prepare_dataset(
                dataset_path=self.config.dataset_config.dataset_path,
                dataset_format=self.config.dataset_config.dataset_format,
                validation_split=self.config.dataset_config.validation_split
            )

            self.dataset = {
                'train': train_dataset,
                'validation': eval_dataset
            }

            # Setup data collator
            if self.config.training_type == "classification":
                self.data_collator = None  # Use default
            else:
                self.data_collator = DataCollatorForLanguageModeling(
                    tokenizer=self.tokenizer,
                    mlm=False
                )

            logger.info(f"Dataset prepared - Train: {len(train_dataset)} samples")
            if eval_dataset:
                logger.info(f"Validation: {len(eval_dataset)} samples")

        except Exception as e:
            logger.error(f"Failed to prepare dataset: {e}")
            raise

    def setup_training(self) -> None:
        """Setup training arguments and trainer."""
        logger.info("Setting up training configuration...")

        try:
            # Calculate training steps
            total_steps = len(self.dataset['train']) // (self.config.batch_size * self.config.gradient_accumulation_steps) * self.config.num_epochs

            logger.debug(f"Total training steps: {total_steps}")

            self.training_args = TrainingArguments(
                output_dir=self.config.output_dir,
                num_train_epochs=self.config.num_epochs,
                per_device_train_batch_size=self.config.batch_size,
                per_device_eval_batch_size=self.config.batch_size,
                gradient_accumulation_steps=self.config.gradient_accumulation_steps,
                learning_rate=self.config.learning_rate,
                weight_decay=self.config.weight_decay,
                warmup_steps=max(1, int(0.1 * total_steps)),  # 10% warmup
                logging_steps=max(1, total_steps // 100),  # Log 100 times per training
                eval_strategy="steps" if self.dataset.get('validation') else "no",
                eval_steps=max(1, total_steps // 10) if self.dataset.get('validation') else None,
                save_strategy="steps",
                save_steps=max(1, total_steps // 5),  # Save 5 times per training
                save_total_limit=3,
                load_best_model_at_end=True if self.dataset.get('validation') else False,
                metric_for_best_model="eval_loss" if self.dataset.get('validation') else None,
                greater_is_better=False,
                report_to=None,  # Disable wandb/tensorboard by default
                remove_unused_columns=False,
                dataloader_pin_memory=False,
                fp16=torch.cuda.is_available() and not self.use_unsloth,  # Unsloth handles precision
                gradient_checkpointing=True and not self.use_unsloth,  # Unsloth handles checkpointing
                optim="adamw_torch",
                lr_scheduler_type="cosine",
                logging_dir=os.path.join(self.config.output_dir, "logs"),
                run_name=f"training_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
            )

            # Initialize trainer
            if self.use_unsloth:
                logger.info("Initializing Unsloth trainer...")
                self.trainer = UnslothTrainer(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    train_dataset=self.dataset['train'],
                    eval_dataset=self.dataset.get('validation'),
                    args=self.training_args,
                    data_collator=self.data_collator,
                )
            else:
                logger.info("Initializing standard trainer...")
                self.trainer = Trainer(
                    model=self.model,
                    args=self.training_args,
                    train_dataset=self.dataset['train'],
                    eval_dataset=self.dataset.get('validation'),
                    tokenizer=self.tokenizer,
                    data_collator=self.data_collator
                )

            logger.info("Training setup completed successfully")

        except Exception as e:
            logger.error(f"Failed to setup training: {e}")
            raise

    def train(self) -> str:
        """Execute the training process."""
        logger.info("=" * 60)
        logger.info("STARTING LLM TRAINING")
        logger.info("=" * 60)

        try:
            # Validate configuration
            issues = self.validate_config()
            if issues:
                raise ValueError(f"Configuration issues: {issues}")

            # Load model and prepare dataset
            logger.info("Step 1/5: Loading model...")
            self.load_model()

            logger.info("Step 2/5: Preparing dataset...")
            self.prepare_dataset()

            logger.info("Step 3/5: Setting up training...")
            self.setup_training()

            # Save training config
            self.save_training_config()

            logger.info("Step 4/5: Starting training...")
            logger.info(f"Training with {'Unsloth acceleration' if self.use_unsloth else 'standard HuggingFace'}")

            # Start training
            train_result = self.trainer.train()

            logger.info("Step 5/5: Saving model...")
            # Save final model
            final_model_path = os.path.join(self.config.output_dir, "final_model")
            self.save_model(final_model_path)

            # Save training metrics
            metrics_path = os.path.join(self.config.output_dir, "training_metrics.json")
            with open(metrics_path, 'w') as f:
                json.dump(train_result.metrics, f, indent=2)

            logger.info("=" * 60)
            logger.info("TRAINING COMPLETED SUCCESSFULLY!")
            logger.info("=" * 60)
            logger.info(f"Model saved to: {final_model_path}")
            logger.info(f"Training metrics saved to: {metrics_path}")

            return final_model_path

        except Exception as e:
            logger.error("=" * 60)
            logger.error("TRAINING FAILED!")
            logger.error("=" * 60)
            logger.error(f"Error: {e}")
            logger.error("Check the error logs for detailed information")
            raise

    def save_model(self, output_path: str) -> None:
        """Save the trained model."""
        logger.info(f"Saving model to: {output_path}")

        try:
            os.makedirs(output_path, exist_ok=True)

            # Save model and tokenizer
            self.trainer.save_model(output_path)
            self.tokenizer.save_pretrained(output_path)

            # Save LoRA adapters if used
            if self.config.lora_config and self.config.lora_config.use_lora:
                adapter_path = os.path.join(output_path, "adapter_model")
                if hasattr(self.model, 'save_pretrained'):
                    self.model.save_pretrained(adapter_path)
                    logger.info(f"LoRA adapters saved to: {adapter_path}")

            # Save additional metadata
            metadata = {
                "model_name": self.config.model_name,
                "training_type": self.config.training_type,
                "use_unsloth": self.use_unsloth,
                "use_lora": self.config.lora_config.use_lora if self.config.lora_config else False,
                "saved_at": datetime.datetime.now().isoformat(),
                "config": self.config.to_dict()
            }

            with open(os.path.join(output_path, "training_metadata.json"), 'w') as f:
                json.dump(metadata, f, indent=2)

            logger.info(f"Model saved successfully to: {output_path}")

        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            raise


class StableDiffusionTrainer(BaseTrainer):
    """
    Trainer for Stable Diffusion models.

    Supports:
    - DreamBooth training
    - LoRA training
    - Textual Inversion
    - Custom dataset training
    """

    def __init__(self, config: TrainingConfig):
        super().__init__(config)

        if not DIFFUSERS_AVAILABLE:
            raise ImportError("Diffusers not available. Install with: pip install diffusers")

        self.unet = None
        self.vae = None
        self.text_encoder = None
        self.scheduler = None

    def load_model(self) -> None:
        """Load Stable Diffusion model components."""
        logger.info(f"Loading Stable Diffusion model: {self.config.model_name}")

        # Load pipeline
        pipeline = StableDiffusionPipeline.from_pretrained(
            self.config.model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )

        self.unet = pipeline.unet
        self.vae = pipeline.vae
        self.text_encoder = pipeline.text_encoder
        self.tokenizer = pipeline.tokenizer
        self.scheduler = pipeline.scheduler

        logger.info("Stable Diffusion model loaded successfully")

    def prepare_dataset(self) -> None:
        """Prepare image dataset for training."""
        # Implementation for image dataset preparation
        logger.info("Preparing image dataset...")
        # This would involve loading images, captions, and preprocessing
        pass

    def setup_training(self) -> None:
        """Setup training for Stable Diffusion."""
        logger.info("Setting up Stable Diffusion training...")
        # Implementation for SD training setup
        pass

    def train(self) -> str:
        """Execute Stable Diffusion training."""
        logger.info("Starting Stable Diffusion training...")

        # Validate configuration
        issues = self.validate_config()
        if issues:
            raise ValueError(f"Configuration issues: {issues}")

        # Implementation for SD training loop
        output_path = os.path.join(self.config.output_dir, "trained_model")

        logger.info(f"Stable Diffusion training completed! Model saved to: {output_path}")
        return output_path

    def save_model(self, output_path: str) -> None:
        """Save trained Stable Diffusion model."""
        os.makedirs(output_path, exist_ok=True)
        # Implementation for saving SD model
        logger.info(f"Stable Diffusion model saved to: {output_path}")


class MLTrainer(BaseTrainer):
    """
    Trainer for traditional ML models.

    Supports:
    - Scikit-learn models
    - XGBoost/LightGBM
    - Custom ML pipelines
    """

    def __init__(self, config: TrainingConfig):
        super().__init__(config)

        if not SKLEARN_AVAILABLE:
            raise ImportError("Scikit-learn not available. Install with: pip install scikit-learn xgboost")

        self.ml_model = None
        self.X_train = None
        self.y_train = None
        self.X_val = None
        self.y_val = None

    def load_model(self) -> None:
        """Initialize ML model."""
        logger.info(f"Initializing ML model: {self.config.model_name}")

        # Model factory based on model_name
        if "xgboost" in self.config.model_name.lower():
            self.ml_model = xgb.XGBClassifier()
        elif "random_forest" in self.config.model_name.lower():
            from sklearn.ensemble import RandomForestClassifier
            self.ml_model = RandomForestClassifier()
        else:
            raise ValueError(f"ML model type not supported: {self.config.model_name}")

        logger.info("ML model initialized successfully")

    def prepare_dataset(self) -> None:
        """Prepare tabular dataset for ML training."""
        logger.info("Preparing ML dataset...")
        # Implementation for loading and preprocessing tabular data
        pass

    def setup_training(self) -> None:
        """Setup ML training parameters."""
        logger.info("Setting up ML training...")
        # Set hyperparameters based on config
        pass

    def train(self) -> str:
        """Execute ML model training."""
        logger.info("Starting ML training...")

        # Validate configuration
        issues = self.validate_config()
        if issues:
            raise ValueError(f"Configuration issues: {issues}")

        # Implementation for ML training
        output_path = os.path.join(self.config.output_dir, "trained_model.pkl")

        logger.info(f"ML training completed! Model saved to: {output_path}")
        return output_path

    def save_model(self, output_path: str) -> None:
        """Save trained ML model."""
        import joblib
        joblib.dump(self.ml_model, output_path)
        logger.info(f"ML model saved to: {output_path}")


# Legacy alias for backward compatibility
SFTTrainer = LLMTrainer


def create_trainer(config: TrainingConfig) -> BaseTrainer:
    """
    Factory function to create appropriate trainer based on model type.

    Args:
        config: Training configuration

    Returns:
        Appropriate trainer instance
    """
    model_name = config.model_name.lower()

    # Determine trainer type based on model name or training type
    if any(keyword in model_name for keyword in ['stable-diffusion', 'sd-', 'diffusion']):
        return StableDiffusionTrainer(config)
    elif any(keyword in model_name for keyword in ['xgboost', 'random_forest', 'svm', 'linear']):
        return MLTrainer(config)
    elif config.training_type in ['sft', 'instruction', 'chat', 'classification']:
        return LLMTrainer(config)
    else:
        # Default to LLM trainer for language models
        return LLMTrainer(config)