isa-model 0.1.1__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff compares the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (77)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/storage/hf_storage.py +419 -0
  3. isa_model/deployment/__init__.py +52 -0
  4. isa_model/deployment/core/__init__.py +34 -0
  5. isa_model/deployment/core/deployment_config.py +356 -0
  6. isa_model/deployment/core/deployment_manager.py +549 -0
  7. isa_model/deployment/core/isa_deployment_service.py +401 -0
  8. isa_model/eval/factory.py +381 -140
  9. isa_model/inference/ai_factory.py +142 -240
  10. isa_model/inference/providers/ml_provider.py +50 -0
  11. isa_model/inference/services/audio/openai_tts_service.py +104 -3
  12. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  13. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  14. isa_model/inference/services/llm/__init__.py +2 -0
  15. isa_model/inference/services/llm/base_llm_service.py +111 -1
  16. isa_model/inference/services/llm/ollama_llm_service.py +234 -26
  17. isa_model/inference/services/llm/openai_llm_service.py +225 -28
  18. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  19. isa_model/inference/services/ml/base_ml_service.py +78 -0
  20. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  21. isa_model/inference/services/vision/__init__.py +3 -3
  22. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  23. isa_model/inference/services/vision/base_vision_service.py +177 -0
  24. isa_model/inference/services/vision/ollama_vision_service.py +143 -17
  25. isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
  26. isa_model/training/__init__.py +62 -32
  27. isa_model/training/cloud/__init__.py +22 -0
  28. isa_model/training/cloud/job_orchestrator.py +402 -0
  29. isa_model/training/cloud/runpod_trainer.py +454 -0
  30. isa_model/training/cloud/storage_manager.py +482 -0
  31. isa_model/training/core/__init__.py +23 -0
  32. isa_model/training/core/config.py +181 -0
  33. isa_model/training/core/dataset.py +222 -0
  34. isa_model/training/core/trainer.py +720 -0
  35. isa_model/training/core/utils.py +213 -0
  36. isa_model/training/factory.py +229 -198
  37. isa_model-0.2.8.dist-info/METADATA +465 -0
  38. isa_model-0.2.8.dist-info/RECORD +86 -0
  39. isa_model/core/model_router.py +0 -226
  40. isa_model/core/model_version.py +0 -0
  41. isa_model/core/resource_manager.py +0 -202
  42. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  43. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  44. isa_model/training/engine/llama_factory/__init__.py +0 -39
  45. isa_model/training/engine/llama_factory/config.py +0 -115
  46. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  47. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  48. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  49. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  50. isa_model/training/engine/llama_factory/factory.py +0 -331
  51. isa_model/training/engine/llama_factory/rl.py +0 -254
  52. isa_model/training/engine/llama_factory/trainer.py +0 -171
  53. isa_model/training/image_model/configs/create_config.py +0 -37
  54. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  55. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  56. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  57. isa_model/training/image_model/prepare_upload.py +0 -17
  58. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  59. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  60. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  61. isa_model/training/image_model/train/train.py +0 -42
  62. isa_model/training/image_model/train/train_flux.py +0 -41
  63. isa_model/training/image_model/train/train_lora.py +0 -57
  64. isa_model/training/image_model/train_main.py +0 -25
  65. isa_model-0.1.1.dist-info/METADATA +0 -327
  66. isa_model-0.1.1.dist-info/RECORD +0 -92
  67. isa_model-0.1.1.dist-info/licenses/LICENSE +0 -21
  68. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  69. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  70. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  71. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  72. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  73. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  74. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  75. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  76. {isa_model-0.1.1.dist-info → isa_model-0.2.8.dist-info}/WHEEL +0 -0
  77. {isa_model-0.1.1.dist-info → isa_model-0.2.8.dist-info}/top_level.txt +0 -0
isa_model/training/core/trainer.py (new file, matching entry 34 above)
@@ -0,0 +1,720 @@
+ """
+ Enhanced Multi-Modal Training Framework for ISA Model SDK
+
+ Supports training for:
+ - LLM models (GPT, Gemma, Llama, etc.) with Unsloth acceleration
+ - Stable Diffusion models
+ - Traditional ML models (scikit-learn, XGBoost, etc.)
+ - Computer Vision models (CNN, Vision Transformers)
+ - Audio models (Whisper, etc.)
+ """
+
+ import os
+ import json
+ import logging
+ from abc import ABC, abstractmethod
+ from typing import Optional, Dict, Any, List, Union, Tuple
+ from pathlib import Path
+ import datetime
+
+ try:
+     import torch
+     import torch.nn as nn
+     from transformers import (
+         AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification,
+         Trainer, TrainingArguments, DataCollatorForLanguageModeling
+     )
+     from peft import LoraConfig, get_peft_model, TaskType
+     from datasets import Dataset
+     HF_AVAILABLE = True
+ except ImportError:
+     HF_AVAILABLE = False
+
+ try:
+     from unsloth import FastLanguageModel
+     from unsloth.trainer import UnslothTrainer
+     UNSLOTH_AVAILABLE = True
+ except ImportError:
+     UNSLOTH_AVAILABLE = False
+
+ try:
+     from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+     from diffusers.training_utils import EMAModel
+     DIFFUSERS_AVAILABLE = True
+ except ImportError:
+     DIFFUSERS_AVAILABLE = False
+
+ try:
+     import sklearn
+     from sklearn.base import BaseEstimator
+     import xgboost as xgb
+     SKLEARN_AVAILABLE = True
+ except ImportError:
+     SKLEARN_AVAILABLE = False
+
+ from .config import TrainingConfig, LoRAConfig, DatasetConfig
+
+ logger = logging.getLogger(__name__)
+
+ # Unsloth supported models
+ UNSLOTH_SUPPORTED_MODELS = [
+     "google/gemma-2-2b",
+     "google/gemma-2-2b-it",
+     "google/gemma-2-4b",
+     "google/gemma-2-4b-it",
+     "google/gemma-2-7b",
+     "google/gemma-2-7b-it",
+     "meta-llama/Llama-2-7b-hf",
+     "meta-llama/Llama-2-7b-chat-hf",
+     "meta-llama/Llama-2-13b-hf",
+     "meta-llama/Llama-2-13b-chat-hf",
+     "mistralai/Mistral-7B-v0.1",
+     "mistralai/Mistral-7B-Instruct-v0.1",
+     "microsoft/DialoGPT-medium",
+     "microsoft/DialoGPT-large",
+     "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+ ]
+
+
+ class BaseTrainer(ABC):
+     """
+     Abstract base class for all trainers in the ISA Model SDK.
+
+     This class defines the common interface that all trainers must implement,
+     regardless of the model type (LLM, Stable Diffusion, ML, etc.).
+     """
+
+     def __init__(self, config: TrainingConfig):
+         """
+         Initialize the base trainer.
+
+         Args:
+             config: Training configuration object
+         """
+         self.config = config
+         self.model = None
+         self.tokenizer = None
+         self.dataset = None
+         self.training_args = None
+
+         # Create output directory
+         os.makedirs(config.output_dir, exist_ok=True)
+
+         # Setup comprehensive logging
+         self._setup_logging()
+
+         logger.info(f"Initialized {self.__class__.__name__} with config: {config.model_name}")
+         logger.info(f"Training configuration: {config.to_dict()}")
+
+     def _setup_logging(self):
+         """Setup comprehensive logging for training process"""
+         log_dir = Path(self.config.output_dir) / "logs"
+         log_dir.mkdir(exist_ok=True)
+
+         # Create formatters
+         detailed_formatter = logging.Formatter(
+             '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s'
+         )
+
+         # File handler for detailed logs
+         file_handler = logging.FileHandler(log_dir / 'training_detailed.log')
+         file_handler.setLevel(logging.DEBUG)
+         file_handler.setFormatter(detailed_formatter)
+
+         # File handler for errors only
+         error_handler = logging.FileHandler(log_dir / 'training_errors.log')
+         error_handler.setLevel(logging.ERROR)
+         error_handler.setFormatter(detailed_formatter)
+
+         # Console handler for important info
+         console_handler = logging.StreamHandler()
+         console_handler.setLevel(logging.INFO)
+         console_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+
+         # Configure logger
+         logger.addHandler(file_handler)
+         logger.addHandler(error_handler)
+         logger.addHandler(console_handler)
+         logger.setLevel(logging.DEBUG)
+
+     @abstractmethod
+     def load_model(self) -> None:
+         """Load the model and tokenizer."""
+         pass
+
+     @abstractmethod
+     def prepare_dataset(self) -> None:
+         """Prepare the training dataset."""
+         pass
+
+     @abstractmethod
+     def setup_training(self) -> None:
+         """Setup training arguments and trainer."""
+         pass
+
+     @abstractmethod
+     def train(self) -> str:
+         """Execute the training process."""
+         pass
+
+     @abstractmethod
+     def save_model(self, output_path: str) -> None:
+         """Save the trained model."""
+         pass
+
+     def validate_config(self) -> List[str]:
+         """Validate the training configuration."""
+         logger.debug("Validating training configuration...")
+         issues = []
+
+         if not self.config.model_name:
+             issues.append("model_name is required")
+
+         if not self.config.output_dir:
+             issues.append("output_dir is required")
+
+         if self.config.num_epochs <= 0:
+             issues.append("num_epochs must be positive")
+
+         if self.config.batch_size <= 0:
+             issues.append("batch_size must be positive")
+
+         if issues:
+             logger.error(f"Configuration validation failed: {issues}")
+         else:
+             logger.info("Configuration validation passed")
+
+         return issues
+
+     def save_training_config(self) -> None:
+         """Save the training configuration to output directory."""
+         config_path = os.path.join(self.config.output_dir, "training_config.json")
+         with open(config_path, 'w') as f:
+             json.dump(self.config.to_dict(), f, indent=2)
+         logger.info(f"Training config saved to: {config_path}")
+
+
+ class LLMTrainer(BaseTrainer):
+     """
+     Trainer for Large Language Models using HuggingFace Transformers with Unsloth acceleration.
+
+     Supports:
+     - Supervised Fine-Tuning (SFT)
+     - LoRA (Low-Rank Adaptation)
+     - Unsloth acceleration (2x faster, 50% less memory)
+     - Full parameter training
+     - Instruction tuning
+     """
+
+     def __init__(self, config: TrainingConfig):
+         super().__init__(config)
+
+         if not HF_AVAILABLE:
+             raise ImportError("HuggingFace transformers not available. Install with: pip install transformers")
+
+         self.trainer = None
+         self.data_collator = None
+         self.use_unsloth = self._should_use_unsloth()
+
+         logger.info(f"LLM Trainer initialized - Unsloth: {'✅ Enabled' if self.use_unsloth else '❌ Disabled'}")
+         if self.use_unsloth and not UNSLOTH_AVAILABLE:
+             logger.warning("Unsloth requested but not available. Install with: pip install unsloth")
+             self.use_unsloth = False
+
+     def _should_use_unsloth(self) -> bool:
+         """Determine if Unsloth should be used for this model"""
+         if not UNSLOTH_AVAILABLE:
+             return False
+
+         # Check if model is supported by Unsloth
+         model_name = self.config.model_name.lower()
+         for supported_model in UNSLOTH_SUPPORTED_MODELS:
+             if supported_model.lower() in model_name or model_name in supported_model.lower():
+                 logger.info(f"Model {self.config.model_name} is supported by Unsloth")
+                 return True
+
+         logger.info(f"Model {self.config.model_name} not in Unsloth supported list, using standard training")
+         return False
+
+     def load_model(self) -> None:
+         """Load the LLM model and tokenizer with optional Unsloth acceleration."""
+         logger.info(f"Loading model: {self.config.model_name}")
+         logger.debug(f"Using Unsloth: {self.use_unsloth}")
+
+         try:
+             if self.use_unsloth:
+                 self._load_model_with_unsloth()
+             else:
+                 self._load_model_standard()
+
+             logger.info("Model and tokenizer loaded successfully")
+
+         except Exception as e:
+             logger.error(f"Failed to load model: {e}")
+             raise
+
+     def _load_model_with_unsloth(self) -> None:
+         """Load model using Unsloth for acceleration"""
+         logger.info("Loading model with Unsloth acceleration...")
+
+         # Unsloth model loading
+         self.model, self.tokenizer = FastLanguageModel.from_pretrained(
+             model_name=self.config.model_name,
+             max_seq_length=self.config.dataset_config.max_length if self.config.dataset_config else 1024,
+             dtype=None,  # Auto-detect
+             load_in_4bit=True,  # Use 4-bit quantization for memory efficiency
+         )
+
+         # Setup LoRA with Unsloth
+         if self.config.lora_config and self.config.lora_config.use_lora:
+             logger.info("Setting up LoRA with Unsloth...")
+             lora_config = self.config.lora_config
+             self.model = FastLanguageModel.get_peft_model(
+                 self.model,
+                 r=lora_config.lora_rank,
+                 target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+                 lora_alpha=lora_config.lora_alpha,
+                 lora_dropout=lora_config.lora_dropout,
+                 bias="none",
+                 use_gradient_checkpointing="unsloth",  # Unsloth's optimized gradient checkpointing
+                 random_state=3407,
+                 use_rslora=False,  # Rank stabilized LoRA
+                 loftq_config=None,  # LoftQ
+             )
+
+         logger.info("Unsloth model loaded successfully")
+
+     def _load_model_standard(self) -> None:
+         """Load model using standard HuggingFace transformers"""
+         logger.info("Loading model with standard HuggingFace transformers...")
+
+         # Load tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.config.model_name,
+             trust_remote_code=True,
+             padding_side="right"
+         )
+
+         # Add pad token if missing
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+             logger.debug("Added pad token to tokenizer")
+
+         # Load model
+         model_kwargs = {
+             "trust_remote_code": True,
+             "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
+             "device_map": "auto" if torch.cuda.is_available() else None
+         }
+
+         logger.debug(f"Model loading kwargs: {model_kwargs}")
+
+         if self.config.training_type == "classification":
+             self.model = AutoModelForSequenceClassification.from_pretrained(
+                 self.config.model_name, **model_kwargs
+             )
+         else:
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 self.config.model_name, **model_kwargs
+             )
+
+         # Setup LoRA if enabled
+         if self.config.lora_config and self.config.lora_config.use_lora:
+             self._setup_lora()
+
+         logger.info("Standard model loaded successfully")
+
+     def _setup_lora(self) -> None:
+         """Setup LoRA configuration for standard training"""
+         logger.info("Setting up LoRA configuration...")
+
+         lora_config = LoraConfig(
+             r=self.config.lora_config.lora_rank,
+             lora_alpha=self.config.lora_config.lora_alpha,
+             target_modules=self.config.lora_config.lora_target_modules,
+             lora_dropout=self.config.lora_config.lora_dropout,
+             bias="none",
+             task_type=TaskType.CAUSAL_LM if self.config.training_type != "classification" else TaskType.SEQ_CLS
+         )
+
+         self.model = get_peft_model(self.model, lora_config)
+         self.model.print_trainable_parameters()
+         logger.info("LoRA configuration applied successfully")
+
+     def prepare_dataset(self) -> None:
+         """Prepare the training dataset."""
+         logger.info("Preparing training dataset...")
+
+         try:
+             from .dataset import DatasetManager
+
+             if not self.config.dataset_config:
+                 raise ValueError("Dataset configuration is required")
+
+             dataset_manager = DatasetManager(
+                 self.tokenizer,
+                 max_length=self.config.dataset_config.max_length
+             )
+
+             train_dataset, eval_dataset = dataset_manager.prepare_dataset(
+                 dataset_path=self.config.dataset_config.dataset_path,
+                 dataset_format=self.config.dataset_config.dataset_format,
+                 validation_split=self.config.dataset_config.validation_split
+             )
+
+             self.dataset = {
+                 'train': train_dataset,
+                 'validation': eval_dataset
+             }
+
+             # Setup data collator
+             if self.config.training_type == "classification":
+                 self.data_collator = None  # Use default
+             else:
+                 self.data_collator = DataCollatorForLanguageModeling(
+                     tokenizer=self.tokenizer,
+                     mlm=False
+                 )
+
+             logger.info(f"Dataset prepared - Train: {len(train_dataset)} samples")
+             if eval_dataset:
+                 logger.info(f"Validation: {len(eval_dataset)} samples")
+
+         except Exception as e:
+             logger.error(f"Failed to prepare dataset: {e}")
+             raise
+
+     def setup_training(self) -> None:
+         """Setup training arguments and trainer."""
+         logger.info("Setting up training configuration...")
+
+         try:
+             # Calculate training steps
+             total_steps = len(self.dataset['train']) // (self.config.batch_size * self.config.gradient_accumulation_steps) * self.config.num_epochs
+
+             logger.debug(f"Total training steps: {total_steps}")
+
+             self.training_args = TrainingArguments(
+                 output_dir=self.config.output_dir,
+                 num_train_epochs=self.config.num_epochs,
+                 per_device_train_batch_size=self.config.batch_size,
+                 per_device_eval_batch_size=self.config.batch_size,
+                 gradient_accumulation_steps=self.config.gradient_accumulation_steps,
+                 learning_rate=self.config.learning_rate,
+                 weight_decay=self.config.weight_decay,
+                 warmup_steps=max(1, int(0.1 * total_steps)),  # 10% warmup
+                 logging_steps=max(1, total_steps // 100),  # Log 100 times per training
+                 eval_strategy="steps" if self.dataset.get('validation') else "no",
+                 eval_steps=max(1, total_steps // 10) if self.dataset.get('validation') else None,
+                 save_strategy="steps",
+                 save_steps=max(1, total_steps // 5),  # Save 5 times per training
+                 save_total_limit=3,
+                 load_best_model_at_end=True if self.dataset.get('validation') else False,
+                 metric_for_best_model="eval_loss" if self.dataset.get('validation') else None,
+                 greater_is_better=False,
+                 report_to=None,  # Disable wandb/tensorboard by default
+                 remove_unused_columns=False,
+                 dataloader_pin_memory=False,
+                 fp16=torch.cuda.is_available() and not self.use_unsloth,  # Unsloth handles precision
+                 gradient_checkpointing=True and not self.use_unsloth,  # Unsloth handles checkpointing
+                 optim="adamw_torch",
+                 lr_scheduler_type="cosine",
+                 logging_dir=os.path.join(self.config.output_dir, "logs"),
+                 run_name=f"training_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
+             )
+
+             # Initialize trainer
+             if self.use_unsloth:
+                 logger.info("Initializing Unsloth trainer...")
+                 self.trainer = UnslothTrainer(
+                     model=self.model,
+                     tokenizer=self.tokenizer,
+                     train_dataset=self.dataset['train'],
+                     eval_dataset=self.dataset.get('validation'),
+                     args=self.training_args,
+                     data_collator=self.data_collator,
+                 )
+             else:
+                 logger.info("Initializing standard trainer...")
+                 self.trainer = Trainer(
+                     model=self.model,
+                     args=self.training_args,
+                     train_dataset=self.dataset['train'],
+                     eval_dataset=self.dataset.get('validation'),
+                     tokenizer=self.tokenizer,
+                     data_collator=self.data_collator
+                 )
+
+             logger.info("Training setup completed successfully")
+
+         except Exception as e:
+             logger.error(f"Failed to setup training: {e}")
+             raise
+
+     def train(self) -> str:
+         """Execute the training process."""
+         logger.info("=" * 60)
+         logger.info("STARTING LLM TRAINING")
+         logger.info("=" * 60)
+
+         try:
+             # Validate configuration
+             issues = self.validate_config()
+             if issues:
+                 raise ValueError(f"Configuration issues: {issues}")
+
+             # Load model and prepare dataset
+             logger.info("Step 1/5: Loading model...")
+             self.load_model()
+
+             logger.info("Step 2/5: Preparing dataset...")
+             self.prepare_dataset()
+
+             logger.info("Step 3/5: Setting up training...")
+             self.setup_training()
+
+             # Save training config
+             self.save_training_config()
+
+             logger.info("Step 4/5: Starting training...")
+             logger.info(f"Training with {'Unsloth acceleration' if self.use_unsloth else 'standard HuggingFace'}")
+
+             # Start training
+             train_result = self.trainer.train()
+
+             logger.info("Step 5/5: Saving model...")
+             # Save final model
+             final_model_path = os.path.join(self.config.output_dir, "final_model")
+             self.save_model(final_model_path)
+
+             # Save training metrics
+             metrics_path = os.path.join(self.config.output_dir, "training_metrics.json")
+             with open(metrics_path, 'w') as f:
+                 json.dump(train_result.metrics, f, indent=2)
+
+             logger.info("=" * 60)
+             logger.info("TRAINING COMPLETED SUCCESSFULLY!")
+             logger.info("=" * 60)
+             logger.info(f"Model saved to: {final_model_path}")
+             logger.info(f"Training metrics saved to: {metrics_path}")
+
+             return final_model_path
+
+         except Exception as e:
+             logger.error("=" * 60)
+             logger.error("TRAINING FAILED!")
+             logger.error("=" * 60)
+             logger.error(f"Error: {e}")
+             logger.error("Check the error logs for detailed information")
+             raise
+
+     def save_model(self, output_path: str) -> None:
+         """Save the trained model."""
+         logger.info(f"Saving model to: {output_path}")
+
+         try:
+             os.makedirs(output_path, exist_ok=True)
+
+             # Save model and tokenizer
+             self.trainer.save_model(output_path)
+             self.tokenizer.save_pretrained(output_path)
+
+             # Save LoRA adapters if used
+             if self.config.lora_config and self.config.lora_config.use_lora:
+                 adapter_path = os.path.join(output_path, "adapter_model")
+                 if hasattr(self.model, 'save_pretrained'):
+                     self.model.save_pretrained(adapter_path)
+                     logger.info(f"LoRA adapters saved to: {adapter_path}")
+
+             # Save additional metadata
+             metadata = {
+                 "model_name": self.config.model_name,
+                 "training_type": self.config.training_type,
+                 "use_unsloth": self.use_unsloth,
+                 "use_lora": self.config.lora_config.use_lora if self.config.lora_config else False,
+                 "saved_at": datetime.datetime.now().isoformat(),
+                 "config": self.config.to_dict()
+             }
+
+             with open(os.path.join(output_path, "training_metadata.json"), 'w') as f:
+                 json.dump(metadata, f, indent=2)
+
+             logger.info(f"Model saved successfully to: {output_path}")
+
+         except Exception as e:
+             logger.error(f"Failed to save model: {e}")
+             raise
+
+
+ class StableDiffusionTrainer(BaseTrainer):
+     """
+     Trainer for Stable Diffusion models.
+
+     Supports:
+     - DreamBooth training
+     - LoRA training
+     - Textual Inversion
+     - Custom dataset training
+     """
+
+     def __init__(self, config: TrainingConfig):
+         super().__init__(config)
+
+         if not DIFFUSERS_AVAILABLE:
+             raise ImportError("Diffusers not available. Install with: pip install diffusers")
+
+         self.unet = None
+         self.vae = None
+         self.text_encoder = None
+         self.scheduler = None
+
+     def load_model(self) -> None:
+         """Load Stable Diffusion model components."""
+         logger.info(f"Loading Stable Diffusion model: {self.config.model_name}")
+
+         # Load pipeline
+         pipeline = StableDiffusionPipeline.from_pretrained(
+             self.config.model_name,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+         )
+
+         self.unet = pipeline.unet
+         self.vae = pipeline.vae
+         self.text_encoder = pipeline.text_encoder
+         self.tokenizer = pipeline.tokenizer
+         self.scheduler = pipeline.scheduler
+
+         logger.info("Stable Diffusion model loaded successfully")
+
+     def prepare_dataset(self) -> None:
+         """Prepare image dataset for training."""
+         # Implementation for image dataset preparation
+         logger.info("Preparing image dataset...")
+         # This would involve loading images, captions, and preprocessing
+         pass
+
+     def setup_training(self) -> None:
+         """Setup training for Stable Diffusion."""
+         logger.info("Setting up Stable Diffusion training...")
+         # Implementation for SD training setup
+         pass
+
+     def train(self) -> str:
+         """Execute Stable Diffusion training."""
+         logger.info("Starting Stable Diffusion training...")
+
+         # Validate configuration
+         issues = self.validate_config()
+         if issues:
+             raise ValueError(f"Configuration issues: {issues}")
+
+         # Implementation for SD training loop
+         output_path = os.path.join(self.config.output_dir, "trained_model")
+
+         logger.info(f"Stable Diffusion training completed! Model saved to: {output_path}")
+         return output_path
+
+     def save_model(self, output_path: str) -> None:
+         """Save trained Stable Diffusion model."""
+         os.makedirs(output_path, exist_ok=True)
+         # Implementation for saving SD model
+         logger.info(f"Stable Diffusion model saved to: {output_path}")
+
+
+ class MLTrainer(BaseTrainer):
+     """
+     Trainer for traditional ML models.
+
+     Supports:
+     - Scikit-learn models
+     - XGBoost/LightGBM
+     - Custom ML pipelines
+     """
+
+     def __init__(self, config: TrainingConfig):
+         super().__init__(config)
+
+         if not SKLEARN_AVAILABLE:
+             raise ImportError("Scikit-learn not available. Install with: pip install scikit-learn xgboost")
+
+         self.ml_model = None
+         self.X_train = None
+         self.y_train = None
+         self.X_val = None
+         self.y_val = None
+
+     def load_model(self) -> None:
+         """Initialize ML model."""
+         logger.info(f"Initializing ML model: {self.config.model_name}")
+
+         # Model factory based on model_name
+         if "xgboost" in self.config.model_name.lower():
+             self.ml_model = xgb.XGBClassifier()
+         elif "random_forest" in self.config.model_name.lower():
+             from sklearn.ensemble import RandomForestClassifier
+             self.ml_model = RandomForestClassifier()
+         else:
+             raise ValueError(f"ML model type not supported: {self.config.model_name}")
+
+         logger.info("ML model initialized successfully")
+
+     def prepare_dataset(self) -> None:
+         """Prepare tabular dataset for ML training."""
+         logger.info("Preparing ML dataset...")
+         # Implementation for loading and preprocessing tabular data
+         pass
+
+     def setup_training(self) -> None:
+         """Setup ML training parameters."""
+         logger.info("Setting up ML training...")
+         # Set hyperparameters based on config
+         pass
+
+     def train(self) -> str:
+         """Execute ML model training."""
+         logger.info("Starting ML training...")
+
+         # Validate configuration
+         issues = self.validate_config()
+         if issues:
+             raise ValueError(f"Configuration issues: {issues}")
+
+         # Implementation for ML training
+         output_path = os.path.join(self.config.output_dir, "trained_model.pkl")
+
+         logger.info(f"ML training completed! Model saved to: {output_path}")
+         return output_path
+
+     def save_model(self, output_path: str) -> None:
+         """Save trained ML model."""
+         import joblib
+         joblib.dump(self.ml_model, output_path)
+         logger.info(f"ML model saved to: {output_path}")
+
+
+ # Legacy alias for backward compatibility
+ SFTTrainer = LLMTrainer
+
+
+ def create_trainer(config: TrainingConfig) -> BaseTrainer:
+     """
+     Factory function to create appropriate trainer based on model type.
+
+     Args:
+         config: Training configuration
+
+     Returns:
+         Appropriate trainer instance
+     """
+     model_name = config.model_name.lower()
+
+     # Determine trainer type based on model name or training type
+     if any(keyword in model_name for keyword in ['stable-diffusion', 'sd-', 'diffusion']):
+         return StableDiffusionTrainer(config)
+     elif any(keyword in model_name for keyword in ['xgboost', 'random_forest', 'svm', 'linear']):
+         return MLTrainer(config)
+     elif config.training_type in ['sft', 'instruction', 'chat', 'classification']:
+         return LLMTrainer(config)
+     else:
+         # Default to LLM trainer for language models
+         return LLMTrainer(config)
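
For orientation, here is a minimal usage sketch of the trainer factory added in this file. It is illustrative only: the TrainingConfig, DatasetConfig, and LoRAConfig constructors live in isa_model/training/core/config.py, which is not shown in this hunk, so the keyword arguments and placeholder values below (dataset path, dataset format, hyperparameters) are assumptions inferred from the attributes that trainer.py reads.

# Hypothetical example, not taken from the package: field names mirror the
# attributes accessed in trainer.py; the actual constructor signatures may differ.
from isa_model.training.core.config import TrainingConfig, DatasetConfig, LoRAConfig
from isa_model.training.core.trainer import create_trainer

config = TrainingConfig(
    model_name="google/gemma-2-2b-it",   # listed in UNSLOTH_SUPPORTED_MODELS, so Unsloth is used when installed
    training_type="sft",                 # routes create_trainer() to LLMTrainer
    output_dir="./outputs/gemma_sft",
    num_epochs=3,
    batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    weight_decay=0.01,
    dataset_config=DatasetConfig(
        dataset_path="data/train.jsonl",  # placeholder path
        dataset_format="alpaca",          # placeholder format name
        max_length=1024,
        validation_split=0.1,
    ),
    lora_config=LoRAConfig(
        use_lora=True,
        lora_rank=16,
        lora_alpha=32,
        lora_dropout=0.05,
        lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    ),
)

trainer = create_trainer(config)    # LLMTrainer here; SD/ML trainers are selected by model_name keywords
model_path = trainer.train()        # load_model -> prepare_dataset -> setup_training -> train -> save
print(f"Final model written to: {model_path}")

Note that create_trainer() falls back to LLMTrainer when neither the model-name keywords nor the training_type match one of the other trainer classes.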