cortex-llm 1.0.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortex/__init__.py +73 -0
- cortex/__main__.py +83 -0
- cortex/config.py +329 -0
- cortex/conversation_manager.py +468 -0
- cortex/fine_tuning/__init__.py +8 -0
- cortex/fine_tuning/dataset.py +332 -0
- cortex/fine_tuning/mlx_lora_trainer.py +502 -0
- cortex/fine_tuning/trainer.py +957 -0
- cortex/fine_tuning/wizard.py +707 -0
- cortex/gpu_validator.py +467 -0
- cortex/inference_engine.py +727 -0
- cortex/metal/__init__.py +275 -0
- cortex/metal/gpu_validator.py +177 -0
- cortex/metal/memory_pool.py +886 -0
- cortex/metal/mlx_accelerator.py +678 -0
- cortex/metal/mlx_converter.py +638 -0
- cortex/metal/mps_optimizer.py +417 -0
- cortex/metal/optimizer.py +665 -0
- cortex/metal/performance_profiler.py +364 -0
- cortex/model_downloader.py +130 -0
- cortex/model_manager.py +2187 -0
- cortex/quantization/__init__.py +5 -0
- cortex/quantization/dynamic_quantizer.py +736 -0
- cortex/template_registry/__init__.py +15 -0
- cortex/template_registry/auto_detector.py +144 -0
- cortex/template_registry/config_manager.py +234 -0
- cortex/template_registry/interactive.py +260 -0
- cortex/template_registry/registry.py +347 -0
- cortex/template_registry/template_profiles/__init__.py +5 -0
- cortex/template_registry/template_profiles/base.py +142 -0
- cortex/template_registry/template_profiles/complex/__init__.py +5 -0
- cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
- cortex/template_registry/template_profiles/standard/__init__.py +9 -0
- cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
- cortex/template_registry/template_profiles/standard/chatml.py +82 -0
- cortex/template_registry/template_profiles/standard/gemma.py +103 -0
- cortex/template_registry/template_profiles/standard/llama.py +87 -0
- cortex/template_registry/template_profiles/standard/simple.py +65 -0
- cortex/ui/__init__.py +120 -0
- cortex/ui/cli.py +1685 -0
- cortex/ui/markdown_render.py +185 -0
- cortex/ui/terminal_app.py +534 -0
- cortex_llm-1.0.0.dist-info/METADATA +275 -0
- cortex_llm-1.0.0.dist-info/RECORD +48 -0
- cortex_llm-1.0.0.dist-info/WHEEL +5 -0
- cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
- cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
- cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/fine_tuning/trainer.py
@@ -0,0 +1,957 @@
"""LoRA training implementation using MLX."""

import logging
import time
import os
import math
from pathlib import Path
from typing import Optional, Dict, Any, Callable, Tuple
from dataclasses import dataclass
import json
import shutil

try:
    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim
    from mlx.utils import tree_map
    MLX_AVAILABLE = True
except Exception as exc:  # noqa: BLE001
    MLX_AVAILABLE = False
    mx = nn = optim = tree_map = None  # type: ignore
    _MLX_IMPORT_ERROR = exc

# Import MLX LM functions
try:
    from mlx_lm import load as mlx_load
    from mlx_lm.tuner.lora import LoRALinear
    from mlx_lm.tuner.trainer import TrainingArgs, train as mlx_train
    from mlx_lm.tuner.datasets import load_dataset as mlx_load_dataset
except ImportError:
    # Fallback implementations
    mlx_load = None
    LoRALinear = None
    TrainingArgs = None
    mlx_train = None
    mlx_load_dataset = None

from cortex.model_manager import ModelManager
from cortex.config import Config
from cortex.metal.mlx_accelerator import MLXAccelerator, MLXConfig

logger = logging.getLogger(__name__)

@dataclass
class TrainingConfig:
    """Enhanced configuration for fine-tuning with intelligent defaults."""
    # Core training parameters
    epochs: int = 2
    learning_rate: float = 3e-5
    batch_size: int = 1
    gradient_accumulation_steps: int = 4

    # LoRA parameters
    lora_r: int = 16  # LoRA rank
    lora_alpha: int = 32  # LoRA alpha
    lora_dropout: float = 0.1
    target_modules: list = None  # Auto-detect if None
    num_lora_layers: int = 16  # Number of layers to apply LoRA to

    # Optimization parameters
    optimizer_type: str = "adamw"  # adamw, sgd, adafactor
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    warmup_steps: Optional[int] = None  # If None, calculated from warmup_ratio
    warmup_ratio: float = 0.1
    lr_scheduler: str = "linear"  # linear, cosine, constant, polynomial

    # Memory and performance
    gradient_checkpointing: bool = False
    quantization_bits: Optional[int] = None  # 4 or 8 bit quantization
    dataloader_num_workers: int = 0
    fp16: bool = True
    bf16: bool = False

    # Task-specific settings
    task_type: str = "chat"  # chat, completion, structured
    max_sequence_length: int = 2048
    response_template: Optional[str] = None

    # Dataset settings
    train_test_split: float = 0.0  # If > 0, split dataset for validation
    shuffle_dataset: bool = True

    # Advanced settings
    seed: int = 42
    logging_steps: int = 10
    eval_steps: Optional[int] = None
    save_steps: int = 500
    early_stopping_patience: Optional[int] = None

    # Model-aware settings (populated automatically)
    model_size_category: str = "medium"  # tiny, small, medium, large, xlarge
    estimated_parameters_b: float = 2.0  # Estimated parameters in billions
    auto_configured: bool = False  # Whether config was auto-generated
    configuration_source: str = "manual"  # manual, smart_quick, smart_balanced, smart_quality

    def __post_init__(self):
        if self.target_modules is None:
            # Default target modules for LoRA
            self.target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]

    def validate(self) -> Tuple[bool, str]:
        """Validate configuration settings."""
        if self.learning_rate <= 0 or self.learning_rate > 1:
            return False, f"Invalid learning rate: {self.learning_rate}"

        if self.epochs < 1 or self.epochs > 100:
            return False, f"Invalid number of epochs: {self.epochs}"

        if self.batch_size < 1 or self.batch_size > 128:
            return False, f"Invalid batch size: {self.batch_size}"

        if self.lora_r < 1 or self.lora_r > 256:
            return False, f"Invalid LoRA rank: {self.lora_r}"

        if self.quantization_bits and self.quantization_bits not in [4, 8]:
            return False, f"Invalid quantization bits: {self.quantization_bits}. Must be 4 or 8."

        return True, "Configuration is valid"

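
As a quick illustration of how the dataclass above behaves, here is a hypothetical snippet (editorial, not part of the packaged file) that builds a config and checks it with validate(); it assumes nothing beyond the TrainingConfig definition shown above.

    # Hypothetical usage sketch; assumes only the TrainingConfig dataclass above.
    cfg = TrainingConfig(epochs=3, learning_rate=2e-4, lora_r=32, lora_alpha=64)
    ok, message = cfg.validate()
    print(ok, message)         # True "Configuration is valid"
    print(cfg.target_modules)  # ["q_proj", "v_proj", "k_proj", "o_proj"], filled in by __post_init__

    bad = TrainingConfig(quantization_bits=3)
    print(bad.validate())      # (False, "Invalid quantization bits: 3. Must be 4 or 8.")
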
class SmartConfigFactory:
    """Factory for creating intelligent training configurations based on model and data characteristics."""

    # Model size categories (parameters in billions)
    MODEL_CATEGORIES = {
        "tiny": (0, 0.5),       # < 500M parameters (e.g., DistilBERT, small GPT-2)
        "small": (0.5, 2),      # 500M-2B (e.g., GPT-2, small Llama)
        "medium": (2, 8),       # 2B-8B (e.g., Gemma-7B, Llama-2-7B)
        "large": (8, 20),       # 8B-20B (e.g., Llama-2-13B, Mistral-7B variants)
        "xlarge": (20, float('inf'))  # 20B+ (e.g., Llama-2-70B, GPT-3.5+)
    }

    # Optimal settings by model size category
    CATEGORY_DEFAULTS = {
        "tiny": {
            "learning_rate": 5e-4,  # Higher LR for small models
            "epochs": 5,            # More epochs needed
            "lora_r": 8,            # Lower rank sufficient
            "lora_alpha": 16,
            "batch_size": 4,        # Can handle larger batches
            "gradient_accumulation_steps": 2,
            "warmup_ratio": 0.05,   # Less warmup needed
            "weight_decay": 0.001,  # Less regularization
        },
        "small": {
            "learning_rate": 3e-4,
            "epochs": 4,
            "lora_r": 16,
            "lora_alpha": 32,
            "batch_size": 2,
            "gradient_accumulation_steps": 4,
            "warmup_ratio": 0.1,
            "weight_decay": 0.01,
        },
        "medium": {
            "learning_rate": 1e-4,  # Standard settings for most models
            "epochs": 3,
            "lora_r": 16,
            "lora_alpha": 32,
            "batch_size": 1,
            "gradient_accumulation_steps": 8,
            "warmup_ratio": 0.1,
            "weight_decay": 0.01,
        },
        "large": {
            "learning_rate": 5e-5,  # Lower LR for stability
            "epochs": 2,
            "lora_r": 32,           # Higher rank for complex models
            "lora_alpha": 64,
            "batch_size": 1,
            "gradient_accumulation_steps": 16,
            "warmup_ratio": 0.15,   # More warmup
            "weight_decay": 0.01,
        },
        "xlarge": {
            "learning_rate": 2e-5,  # Very conservative
            "epochs": 2,
            "lora_r": 64,           # High rank for very large models
            "lora_alpha": 128,
            "batch_size": 1,
            "gradient_accumulation_steps": 32,
            "warmup_ratio": 0.2,
            "weight_decay": 0.01,
        }
    }

    @classmethod
    def categorize_model_size(cls, size_gb: float, model_manager=None, model_path=None) -> Tuple[str, float]:
        """Categorize model based on actual parameters if possible, fallback to size estimation."""
        estimated_params_b = size_gb / 2.0  # Fallback estimation

        # Try to get accurate parameter count if model_manager and path are provided
        if model_manager and model_path:
            try:
                from pathlib import Path
                actual_params_b = model_manager.get_model_parameters_smart(Path(model_path))
                if actual_params_b is not None:
                    estimated_params_b = actual_params_b  # Already in billions
                    logger.info(f"Using accurate parameter count: {estimated_params_b:.2f}B parameters")
                else:
                    logger.warning(f"Could not detect parameters, using size estimation: {estimated_params_b:.2f}B")
            except Exception as e:
                logger.warning(f"Parameter detection failed: {e}, using size estimation")

        for category, (min_params, max_params) in cls.MODEL_CATEGORIES.items():
            if min_params <= estimated_params_b < max_params:
                return category, estimated_params_b

        # Fallback to medium if can't categorize
        return "medium", estimated_params_b

    @classmethod
    def analyze_dataset(cls, dataset_path: Path) -> Dict[str, Any]:
        """Analyze dataset to inform training configuration."""
        try:
            examples = []
            with open(dataset_path, 'r') as f:
                for line in f:
                    examples.append(json.loads(line.strip()))

            dataset_size = len(examples)

            # Analyze content to detect task type
            task_type = "chat"  # Default
            avg_length = 0

            if examples:
                sample = examples[0]

                # Detect task type from structure
                if 'prompt' in sample and 'response' in sample:
                    task_type = "chat"
                elif 'prompt' in sample and 'completion' in sample:
                    task_type = "completion"
                elif 'text' in sample:
                    task_type = "completion"

                # Calculate average text length
                total_chars = 0
                for example in examples[:100]:  # Sample first 100
                    text = ""
                    if 'text' in example:
                        text = example['text']
                    elif 'prompt' in example and 'response' in example:
                        text = example['prompt'] + example['response']
                    elif 'prompt' in example and 'completion' in example:
                        text = example['prompt'] + example['completion']
                    total_chars += len(text)

                avg_length = total_chars // min(len(examples), 100)

            return {
                "size": dataset_size,
                "task_type": task_type,
                "avg_length": avg_length,
                "size_category": cls._get_dataset_size_category(dataset_size)
            }
        except Exception as e:
            logger.warning(f"Failed to analyze dataset: {e}")
            return {
                "size": 0,
                "task_type": "chat",
                "avg_length": 1000,
                "size_category": "small"
            }

    @classmethod
    def _get_dataset_size_category(cls, size: int) -> str:
        """Categorize dataset by size."""
        if size < 50:
            return "tiny"
        elif size < 500:
            return "small"
        elif size < 2000:
            return "medium"
        elif size < 10000:
            return "large"
        else:
            return "xlarge"

    @classmethod
    def create_smart_config(
        cls,
        model_size_gb: float,
        dataset_path: Path,
        preset: str = "balanced",
        custom_settings: Optional[Dict[str, Any]] = None,
        model_manager = None,
        model_path: Optional[str] = None
    ) -> TrainingConfig:
        """Create an intelligent training configuration."""

        # Analyze model with accurate parameter detection
        model_category, estimated_params = cls.categorize_model_size(
            model_size_gb, model_manager, model_path
        )

        # Analyze dataset
        dataset_info = cls.analyze_dataset(dataset_path)

        # Get base settings for model category
        base_config = cls.CATEGORY_DEFAULTS[model_category].copy()

        # Apply preset modifications
        if preset == "quick":
            base_config["epochs"] = max(1, base_config["epochs"] - 1)
            base_config["learning_rate"] *= 1.5  # Faster learning
        elif preset == "quality":
            base_config["epochs"] += 1
            base_config["learning_rate"] *= 0.8  # More conservative
            base_config["lora_r"] = min(64, base_config["lora_r"] * 2)  # Higher rank

        # Adjust for dataset size
        dataset_size = dataset_info["size"]
        if dataset_size < 100:  # Small dataset
            base_config["epochs"] = min(base_config["epochs"] + 2, 8)  # More epochs
            base_config["weight_decay"] *= 0.5  # Less regularization
        elif dataset_size > 5000:  # Large dataset
            base_config["epochs"] = max(1, base_config["epochs"] - 1)  # Fewer epochs

        # Adjust for sequence length
        if dataset_info["avg_length"] > 2000:
            base_config["gradient_accumulation_steps"] *= 2  # Handle memory
            base_config["max_sequence_length"] = 4096

        total_mem_gb = cls._get_total_memory_gb()
        memory_guard_applied = cls._apply_memory_guards(base_config, total_mem_gb)

        # Apply custom settings if provided
        if custom_settings:
            base_config.update(custom_settings)

        # Create configuration
        config = TrainingConfig(
            # Core parameters
            epochs=base_config["epochs"],
            learning_rate=base_config["learning_rate"],
            batch_size=base_config["batch_size"],
            gradient_accumulation_steps=base_config["gradient_accumulation_steps"],

            # LoRA parameters
            lora_r=base_config["lora_r"],
            lora_alpha=base_config["lora_alpha"],

            # Optimization
            weight_decay=base_config["weight_decay"],
            warmup_ratio=base_config["warmup_ratio"],

            # Task-specific
            task_type=dataset_info["task_type"],
            max_sequence_length=base_config.get("max_sequence_length", 2048),

            # Metadata
            model_size_category=model_category,
            estimated_parameters_b=estimated_params,
            auto_configured=True,
            configuration_source=f"smart_{preset}{'_memory_guarded' if memory_guard_applied else ''}"
        )

        return config

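
For orientation, a hypothetical call to the factory method above might look as follows; the 4.2 GB size and the dataset path are illustrative values only, and with no model_manager supplied the parameter count falls back to the size/2 estimate.

    # Hypothetical call; values are illustrative, not from the package.
    from pathlib import Path

    config = SmartConfigFactory.create_smart_config(
        model_size_gb=4.2,
        dataset_path=Path("data/train.jsonl"),
        preset="balanced",
    )
    print(config.model_size_category, config.epochs, config.lora_r, config.configuration_source)
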
    @classmethod
    def get_preset_configs(cls) -> Dict[str, Dict[str, Any]]:
        """Get preset configuration descriptions."""
        return {
            "quick": {
                "name": "Quick",
                "description": "Fast training with fewer epochs",
                "use_case": "Quick experimentation and testing",
                "time_factor": 0.7
            },
            "balanced": {
                "name": "Balanced",
                "description": "Optimal balance of speed and quality",
                "use_case": "Most general use cases (recommended)",
                "time_factor": 1.0
            },
            "quality": {
                "name": "Quality",
                "description": "Best results with more training",
                "use_case": "Production models, important tasks",
                "time_factor": 1.5
            }
        }

    @classmethod
    def generate_guidance_message(cls, config: TrainingConfig, model_name: str) -> str:
        """Generate helpful guidance message for the user."""
        messages = []
        if config.configuration_source.endswith("memory_guarded"):
            messages.append("Applied memory guard for this machine: capped batch/seq/accum to avoid GPU/UM pressure")

        # Model-specific guidance
        if config.model_size_category == "tiny":
            messages.append(f"Detected tiny model ({config.estimated_parameters_b:.1f}B params) - using higher learning rate for better convergence")
        elif config.model_size_category == "small":
            messages.append(f"Detected small model ({config.estimated_parameters_b:.1f}B params) - using optimized settings")
        elif config.model_size_category == "large":
            messages.append(f"Detected large model ({config.estimated_parameters_b:.1f}B params) - using careful settings for stability")
        elif config.model_size_category == "xlarge":
            messages.append(f"Detected very large model ({config.estimated_parameters_b:.1f}B params) - using conservative settings for stability")

        # Learning rate guidance
        if config.learning_rate > 1e-4:
            messages.append(f"Using accelerated learning rate ({config.learning_rate:.1e}) - suitable for smaller models")
        elif config.learning_rate < 5e-5:
            messages.append(f"Using conservative learning rate ({config.learning_rate:.1e}) - prevents overfitting in large models")

        # LoRA guidance
        if config.lora_r >= 32:
            messages.append(f"Using high LoRA rank ({config.lora_r}) - captures more model complexity")
        elif config.lora_r <= 8:
            messages.append(f"Using low LoRA rank ({config.lora_r}) - efficient for simpler adaptations")

        # Epoch guidance
        if config.epochs >= 5:
            messages.append(f"Training for {config.epochs} epochs - extra iterations for small datasets")
        elif config.epochs == 1:
            messages.append(f"Single epoch training - suitable for large datasets")

        if not messages:
            messages.append(f"Using optimized settings for {config.model_size_category} model")

        return "\n ".join(messages)

    @staticmethod
    def _get_total_memory_gb() -> Optional[float]:
        """Approximate total unified memory on macOS (used as GPU-visible memory)."""
        try:
            page_size = os.sysconf("SC_PAGE_SIZE")
            phys_pages = os.sysconf("SC_PHYS_PAGES")
            total_bytes = page_size * phys_pages
            return round(total_bytes / (1024**3), 1)
        except Exception as exc:  # noqa: BLE001
            logger.debug(f"Total memory detection failed: {exc}")
            return None

    @classmethod
    def _apply_memory_guards(cls, cfg: Dict[str, Any], total_mem_gb: Optional[float]) -> bool:
        """
        Downscale aggressive settings on lower-memory Apple Silicon to reduce GPU/UM hangs.

        Heuristics:
        - <=16GB: cap seq length to 1024, batch=1, grad_acc<=2
        - <=32GB: cap seq length to 2048, batch<=2, grad_acc<=4
        - Additionally cap effective tokens (batch*grad_acc*max_seq) to avoid runaway memory.
        """
        if not total_mem_gb:
            return False

        guard_applied = False
        effective_tokens = lambda c: c["batch_size"] * c["gradient_accumulation_steps"] * c.get("max_sequence_length", 2048)

        if total_mem_gb <= 16:
            if cfg["batch_size"] > 1:
                cfg["batch_size"] = 1
                guard_applied = True
            if cfg["gradient_accumulation_steps"] > 2:
                cfg["gradient_accumulation_steps"] = 2
                guard_applied = True
            max_seq = cfg.get("max_sequence_length", 2048)
            if max_seq > 1024:
                cfg["max_sequence_length"] = 1024
                guard_applied = True
            target_tokens = 4096
        elif total_mem_gb <= 32:
            if cfg["batch_size"] > 2:
                cfg["batch_size"] = 2
                guard_applied = True
            if cfg["gradient_accumulation_steps"] > 4:
                cfg["gradient_accumulation_steps"] = 4
                guard_applied = True
            max_seq = cfg.get("max_sequence_length", 2048)
            if max_seq > 2048:
                cfg["max_sequence_length"] = 2048
                guard_applied = True
            target_tokens = 8192
        else:
            target_tokens = 12288  # Leave roomy settings for higher-memory hosts

        # Gradient checkpointing trades compute for memory; enable when guarding.
        if guard_applied and not cfg.get("gradient_checkpointing", False):
            cfg["gradient_checkpointing"] = True

        # If the overall token budget is still too high, scale down grad_acc first, then seq length.
        curr_tokens = effective_tokens(cfg)
        if curr_tokens > target_tokens:
            scale = max(1, math.ceil(curr_tokens / target_tokens))
            new_grad_acc = max(1, cfg["gradient_accumulation_steps"] // scale)
            if new_grad_acc < cfg["gradient_accumulation_steps"]:
                cfg["gradient_accumulation_steps"] = new_grad_acc
                guard_applied = True
            curr_tokens = effective_tokens(cfg)
            if curr_tokens > target_tokens:
                new_seq = max(256, cfg.get("max_sequence_length", 2048) // scale)
                if new_seq < cfg.get("max_sequence_length", 2048):
                    cfg["max_sequence_length"] = new_seq
                    guard_applied = True

        if guard_applied:
            logger.info(
                f"Memory guard applied (total_mem={total_mem_gb}GB): "
                f"batch={cfg['batch_size']}, grad_acc={cfg['gradient_accumulation_steps']}, "
                f"max_seq={cfg.get('max_sequence_length', 2048)}"
            )
        return guard_applied

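
To make the token-budget heuristic above concrete, here is a small standalone sketch (not part of the file) that restates the "effective tokens" quantity the guard caps, using the 8192 and 4096 targets that appear in the method.

    # Effective token budget per optimizer step, as used by the memory guard above:
    # batch_size * gradient_accumulation_steps * max_sequence_length.
    def effective_tokens(batch_size: int, grad_acc: int, max_seq: int) -> int:
        return batch_size * grad_acc * max_seq

    print(effective_tokens(1, 8, 2048))  # 16384 -- above the 8192 target for a <=32 GB machine
    print(effective_tokens(1, 4, 2048))  # 8192  -- at the target once grad_acc is capped to 4
    print(effective_tokens(1, 2, 1024))  # 2048  -- under the 4096 target for a <=16 GB machine
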
class LoRATrainer:
    """Trainer for LoRA fine-tuning using MLX."""

    def __init__(self, model_manager: ModelManager, config: Config):
        """Initialize the trainer."""
        self.model_manager = model_manager
        self.config = config
        self.mlx_accelerator = MLXAccelerator(MLXConfig())

    def train(
        self,
        base_model_name: str,
        dataset_path: Path,
        output_name: str,
        config: TrainingConfig,
        progress_callback: Optional[Callable] = None
    ) -> bool:
        """
        Train a model using LoRA.

        Args:
            base_model_name: Name of the base model to fine-tune
            dataset_path: Path to the training dataset
            output_name: Name for the fine-tuned model
            config: Training configuration
            progress_callback: Optional callback for progress updates

        Returns:
            True if training succeeded, False otherwise
        """
        try:
            if not MLX_AVAILABLE:
                logger.error("MLX is not available; fine-tuning requires MLX.")
                if "_MLX_IMPORT_ERROR" in globals():
                    logger.debug(f"MLX import error: {_MLX_IMPORT_ERROR}")  # type: ignore[name-defined]
                return False
            logger.info(f"Starting LoRA training: {base_model_name} -> {output_name}")

            # Step 1: Load base model
            logger.info("Loading base model...")
            model, tokenizer = self._load_base_model(base_model_name)
            if model is None:
                logger.error("Failed to load base model")
                return False

            # Step 2: Apply LoRA layers
            logger.info(f"Applying LoRA with rank={config.lora_r}")
            model = self._apply_lora(model, config)

            # Step 3: Load and prepare dataset
            logger.info("Loading dataset...")
            train_dataset = self._load_dataset(dataset_path, tokenizer, config)
            if train_dataset is None:
                logger.error("Failed to load dataset")
                return False

            # Step 4: Setup optimizer
            optimizer = self._setup_optimizer(model, config)

            # Step 5: Training loop
            logger.info(f"Starting training for {config.epochs} epochs...")
            trained_model = self._training_loop(
                model=model,
                dataset=train_dataset,
                optimizer=optimizer,
                config=config,
                tokenizer=tokenizer,
                progress_callback=progress_callback
            )

            # Step 6: Save fine-tuned model
            logger.info(f"Saving fine-tuned model as {output_name}...")
            success = self._save_model(trained_model, tokenizer, output_name, base_model_name)

            if success:
                logger.info(f"Successfully fine-tuned model saved as {output_name}")
                return True
            else:
                logger.error("Failed to save fine-tuned model")
                return False

        except Exception as e:
            logger.error(f"Training failed: {e}")
            return False

    def _load_base_model(self, model_name: str) -> Tuple[Optional[Any], Optional[Any]]:
        """Load the base model and tokenizer."""
        try:
            # The model should already be loaded by the ModelManager
            # We just need to get it from the cache

            # Try all possible cache keys
            possible_keys = [
                model_name,
                self.model_manager.current_model,
                # Sometimes the model is stored with path as key
                str(Path.home() / ".cortex" / "mlx_models" / model_name),
            ]

            model = None
            tokenizer = None

            for key in possible_keys:
                if key and not model:
                    model = self.model_manager.model_cache.get(key)
                if key and not tokenizer:
                    tokenizer = self.model_manager.tokenizers.get(key)

                if model and tokenizer:
                    logger.info(f"Using loaded model from cache (key: {key})")
                    break

            if model and tokenizer:
                return model, tokenizer

            # If not in cache, this is unexpected since the wizard confirmed the model is loaded
            logger.error(f"Model {model_name} not found in cache. Available keys: {list(self.model_manager.model_cache.keys())}")
            logger.error(f"Current model: {self.model_manager.current_model}")

            # As a fallback, try to load it (but this shouldn't happen)
            logger.warning(f"Attempting to reload model {model_name}")

            # First check if it's already an MLX model to avoid re-conversion
            mlx_path = Path.home() / ".cortex" / "mlx_models" / model_name
            if mlx_path.exists():
                # It's already converted, load it directly
                success, message = self.model_manager.load_model(str(mlx_path), model_name=model_name)
            else:
                # Try loading from original location
                success, message = self.model_manager.load_model(model_name)

            if not success:
                logger.error(f"Failed to load model: {message}")
                return None, None

            # Try to get it from cache again
            model = self.model_manager.model_cache.get(model_name) or self.model_manager.model_cache.get(self.model_manager.current_model)
            tokenizer = self.model_manager.tokenizers.get(model_name) or self.model_manager.tokenizers.get(self.model_manager.current_model)

            if not model or not tokenizer:
                logger.error(f"Model or tokenizer still not available after reload")
                return None, None

            return model, tokenizer

        except Exception as e:
            logger.error(f"Error loading base model: {e}")
            return None, None

    def _apply_lora(self, model: Any, config: TrainingConfig) -> Any:
        """Apply LoRA layers to the model."""
        if LoRALinear is None:
            # Fallback: Simple LoRA implementation
            logger.warning("mlx_lm LoRA not available, using basic implementation")
            return self._apply_basic_lora(model, config)

        # Use mlx_lm's LoRA implementation
        lora_layers = 0

        def apply_lora_to_linear(layer):
            nonlocal lora_layers
            if isinstance(layer, nn.Linear):
                # Check if this is a target module
                for target in config.target_modules:
                    if hasattr(layer, '__name__') and target in str(layer.__name__):
                        # Replace with LoRA layer
                        lora_layers += 1
                        return LoRALinear(
                            in_features=layer.weight.shape[1],
                            out_features=layer.weight.shape[0],
                            r=config.lora_r,
                            alpha=config.lora_alpha,
                            dropout=config.lora_dropout
                        )
                return layer
            return layer

        # Apply LoRA to all linear layers in target modules
        model = tree_map(apply_lora_to_linear, model)
        logger.info(f"Applied LoRA to {lora_layers} layers")

        return model

    def _apply_basic_lora(self, model: Any, config: TrainingConfig) -> Any:
        """Apply basic LoRA implementation."""
        class BasicLoRALinear(nn.Module):
            def __init__(self, linear_layer, r=16, alpha=32):
                super().__init__()
                self.linear = linear_layer
                self.r = r
                self.alpha = alpha

                # LoRA parameters
                in_features = linear_layer.weight.shape[1]
                out_features = linear_layer.weight.shape[0]

                # Low-rank matrices
                self.lora_a = mx.random.normal((r, in_features)) * 0.01
                self.lora_b = mx.zeros((out_features, r))

                # Scaling factor
                self.scaling = alpha / r

            def __call__(self, x):
                # Original forward pass
                result = self.linear(x)

                # Add LoRA contribution
                lora_out = x @ self.lora_a.T @ self.lora_b.T * self.scaling

                return result + lora_out

        # Apply to target modules
        def apply_basic_lora_to_layer(layer):
            if isinstance(layer, nn.Linear):
                return BasicLoRALinear(layer, r=config.lora_r, alpha=config.lora_alpha)
            return layer

        model = tree_map(apply_basic_lora_to_layer, model)
        return model

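
The fallback BasicLoRALinear above computes the usual low-rank update y = Wx + (alpha/r) * B A x, with B initialized to zero so the adapted layer starts out identical to the frozen base layer. A minimal self-contained numpy sketch of that composition (illustrative only, independent of MLX and of this package):

    import numpy as np

    # y = W x + (alpha / r) * B (A x); B == 0 at init, so the update starts as a no-op.
    rng = np.random.default_rng(0)
    in_features, out_features, r, alpha = 8, 4, 2, 4

    W = rng.normal(size=(out_features, in_features))  # frozen base weight
    A = rng.normal(size=(r, in_features)) * 0.01      # trainable, small random init
    B = np.zeros((out_features, r))                   # trainable, zero init

    x = rng.normal(size=(in_features,))
    y = W @ x + (alpha / r) * (B @ (A @ x))
    assert np.allclose(y, W @ x)                      # zero-initialized B leaves the output unchanged
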
    def _load_dataset(self, dataset_path: Path, tokenizer: Any, config: TrainingConfig) -> Optional[Any]:
        """Load and prepare the dataset."""
        try:
            # Load JSONL dataset
            examples = []
            with open(dataset_path, 'r') as f:
                for line in f:
                    data = json.loads(line.strip())
                    examples.append(data)

            # Tokenize examples
            tokenized_examples = []
            max_seq_len = getattr(config, "max_sequence_length", None)
            for example in examples:
                # Format as conversation
                if 'prompt' in example and 'response' in example:
                    text = f"User: {example['prompt']}\nAssistant: {example['response']}"
                elif 'text' in example:
                    text = example['text']
                else:
                    continue

                # Tokenize
                tokens = tokenizer.encode(text)
                if max_seq_len and len(tokens) > max_seq_len:
                    tokens = tokens[:max_seq_len]
                tokenized_examples.append({
                    'input_ids': mx.array(tokens),
                    'labels': mx.array(tokens)  # For causal LM
                })

            logger.info(f"Loaded {len(tokenized_examples)} training examples")
            return tokenized_examples

        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            return None

    def _setup_optimizer(self, model: Any, config: TrainingConfig) -> Any:
        """Setup the optimizer."""
        # Get trainable parameters (LoRA parameters only)
        trainable_params = []

        def get_lora_params(module, prefix=""):
            # Check for LoRA parameters
            if hasattr(module, 'lora_a'):
                trainable_params.append(module.lora_a)
            if hasattr(module, 'lora_b'):
                trainable_params.append(module.lora_b)

            # Try to iterate over child modules
            try:
                # Try vars() first (for regular Python objects)
                children = vars(module).items()
            except TypeError:
                # If vars() doesn't work, try __dict__ directly
                if hasattr(module, '__dict__'):
                    children = module.__dict__.items()
                else:
                    # For MLX modules, try to get children differently
                    children = []
                    if hasattr(module, 'children'):
                        for child in module.children():
                            children.append(('', child))

            for name, child in children:
                if isinstance(child, nn.Module):
                    get_lora_params(child, f"{prefix}.{name}")

        # Only try to extract LoRA params if model is a Module
        if isinstance(model, nn.Module):
            get_lora_params(model)

        if not trainable_params:
            # If no LoRA params found, train all parameters (fallback)
            logger.warning("No LoRA parameters found, training all parameters")
            # For MLX models, we need to get parameters differently
            if hasattr(model, 'parameters'):
                trainable_params = list(model.parameters())
            else:
                logger.error("Model has no parameters() method")
                trainable_params = []

        # Create optimizer
        optimizer = optim.AdamW(
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Initialize optimizer state
        optimizer.init(trainable_params)

        logger.info(f"Initialized optimizer with {len(trainable_params)} trainable parameters")
        return optimizer

    def _training_loop(
        self,
        model: Any,
        dataset: list,
        optimizer: Any,
        config: TrainingConfig,
        tokenizer: Any,
        progress_callback: Optional[Callable] = None
    ) -> Any:
        """Main training loop."""
        model.train()

        total_steps = len(dataset) * config.epochs
        current_step = 0

        for epoch in range(config.epochs):
            epoch_loss = 0.0
            batch_loss = 0.0

            for i, batch in enumerate(dataset):
                # Forward pass
                input_ids = batch['input_ids']
                labels = batch['labels']

                # Compute loss
                logits = model(input_ids[None, :])  # Add batch dimension

                # Cross-entropy loss
                loss = mx.mean(
                    nn.losses.cross_entropy(
                        logits[0, :-1],  # All but last prediction
                        labels[1:],  # All but first token
                        reduction='none'
                    )
                )

                # Backward pass
                loss_value, grads = mx.value_and_grad(lambda m: loss)(model)

                # Gradient accumulation
                batch_loss += loss_value.item()

                if (i + 1) % config.gradient_accumulation_steps == 0:
                    # Update weights
                    optimizer.update(model, grads)

                    # Clear accumulated loss
                    avg_loss = batch_loss / config.gradient_accumulation_steps
                    epoch_loss += avg_loss
                    batch_loss = 0.0

                    # Progress callback
                    if progress_callback:
                        progress_callback(epoch, i, avg_loss)

                current_step += 1

                # Evaluate to ensure computation
                mx.eval(model.parameters())

            # Log epoch statistics
            avg_epoch_loss = epoch_loss / (len(dataset) / config.gradient_accumulation_steps)
            logger.info(f"Epoch {epoch+1}/{config.epochs} - Loss: {avg_epoch_loss:.4f}")

        return model

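
The loss in the loop above is next-token cross-entropy: the logits drop their last position and the labels drop their first token, so position t is scored against token t+1. A tiny standalone sketch of that alignment (plain Python, no MLX required, not part of the file):

    # Next-token alignment used by the loss above: logits[:-1] are scored against tokens[1:].
    tokens = [101, 7, 42, 9]    # a toy token sequence
    scored_positions = tokens[:-1]
    targets = tokens[1:]
    for pos, (seen, target) in enumerate(zip(scored_positions, targets)):
        print(f"position {pos}: last token seen {seen}, trained to predict {target}")
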
    def _save_model(
        self,
        model: Any,
        tokenizer: Any,
        output_name: str,
        base_model_name: str
    ) -> bool:
        """Save the fine-tuned model."""
        try:
            # Create output directory in MLX models folder for consistency
            output_dir = Path.home() / ".cortex" / "mlx_models" / output_name
            output_dir.mkdir(parents=True, exist_ok=True)

            # Save model weights
            weights_path = output_dir / "model.safetensors"

            # Get model state dict
            state_dict = {}

            def extract_weights(module, prefix=""):
                for name, param in vars(module).items():
                    if isinstance(param, mx.array):
                        state_dict[f"{prefix}.{name}"] = param
                    elif isinstance(param, nn.Module):
                        extract_weights(param, f"{prefix}.{name}")

            extract_weights(model)

            # Save using safetensors format (or numpy for simplicity)
            import numpy as np
            np_state_dict = {k: v.tolist() for k, v in state_dict.items()}

            with open(weights_path, 'w') as f:
                json.dump(np_state_dict, f)

            # Save tokenizer
            if hasattr(tokenizer, 'save_pretrained'):
                tokenizer.save_pretrained(output_dir)

            # Save config
            config_data = {
                "base_model": base_model_name,
                "model_type": "fine-tuned",
                "fine_tuning_method": "LoRA",
                "created_at": time.strftime("%Y-%m-%d %H:%M:%S")
            }

            with open(output_dir / "config.json", 'w') as f:
                json.dump(config_data, f, indent=2)

            # Copy any additional files from base model
            base_model_path = Path.home() / ".cortex" / "models" / base_model_name
            if base_model_path.exists():
                for file in ['tokenizer_config.json', 'special_tokens_map.json', 'vocab.json']:
                    src = base_model_path / file
                    if src.exists():
                        shutil.copy2(src, output_dir / file)

            logger.info(f"Model saved to {output_dir}")
            return True

        except Exception as e:
            logger.error(f"Error saving model: {e}")
            return False
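
Putting the pieces together, a hypothetical end-to-end invocation of the classes in this file might look like the following. The model name, dataset path, and the Config()/ModelManager(config) constructions are illustrative assumptions, not documented API of this release; only the SmartConfigFactory and LoRATrainer signatures are taken from the file above.

    from pathlib import Path

    from cortex.config import Config
    from cortex.model_manager import ModelManager
    from cortex.fine_tuning.trainer import LoRATrainer, SmartConfigFactory

    # Assumed constructors and names; shown only to illustrate the flow.
    config = Config()
    manager = ModelManager(config)
    dataset = Path("data/train.jsonl")

    training_config = SmartConfigFactory.create_smart_config(
        model_size_gb=4.0,
        dataset_path=dataset,
        preset="balanced",
        model_manager=manager,
        model_path="my-base-model",
    )

    trainer = LoRATrainer(manager, config)
    ok = trainer.train(
        base_model_name="my-base-model",
        dataset_path=dataset,
        output_name="my-base-model-ft",
        config=training_config,
        progress_callback=lambda epoch, step, loss: print(epoch, step, loss),
    )
    print("training succeeded:", ok)
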