cortex-llm 1.0.0 (cortex_llm-1.0.0-py3-none-any.whl)

This diff shows the contents of a publicly released package version as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (48)
  1. cortex/__init__.py +73 -0
  2. cortex/__main__.py +83 -0
  3. cortex/config.py +329 -0
  4. cortex/conversation_manager.py +468 -0
  5. cortex/fine_tuning/__init__.py +8 -0
  6. cortex/fine_tuning/dataset.py +332 -0
  7. cortex/fine_tuning/mlx_lora_trainer.py +502 -0
  8. cortex/fine_tuning/trainer.py +957 -0
  9. cortex/fine_tuning/wizard.py +707 -0
  10. cortex/gpu_validator.py +467 -0
  11. cortex/inference_engine.py +727 -0
  12. cortex/metal/__init__.py +275 -0
  13. cortex/metal/gpu_validator.py +177 -0
  14. cortex/metal/memory_pool.py +886 -0
  15. cortex/metal/mlx_accelerator.py +678 -0
  16. cortex/metal/mlx_converter.py +638 -0
  17. cortex/metal/mps_optimizer.py +417 -0
  18. cortex/metal/optimizer.py +665 -0
  19. cortex/metal/performance_profiler.py +364 -0
  20. cortex/model_downloader.py +130 -0
  21. cortex/model_manager.py +2187 -0
  22. cortex/quantization/__init__.py +5 -0
  23. cortex/quantization/dynamic_quantizer.py +736 -0
  24. cortex/template_registry/__init__.py +15 -0
  25. cortex/template_registry/auto_detector.py +144 -0
  26. cortex/template_registry/config_manager.py +234 -0
  27. cortex/template_registry/interactive.py +260 -0
  28. cortex/template_registry/registry.py +347 -0
  29. cortex/template_registry/template_profiles/__init__.py +5 -0
  30. cortex/template_registry/template_profiles/base.py +142 -0
  31. cortex/template_registry/template_profiles/complex/__init__.py +5 -0
  32. cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
  33. cortex/template_registry/template_profiles/standard/__init__.py +9 -0
  34. cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
  35. cortex/template_registry/template_profiles/standard/chatml.py +82 -0
  36. cortex/template_registry/template_profiles/standard/gemma.py +103 -0
  37. cortex/template_registry/template_profiles/standard/llama.py +87 -0
  38. cortex/template_registry/template_profiles/standard/simple.py +65 -0
  39. cortex/ui/__init__.py +120 -0
  40. cortex/ui/cli.py +1685 -0
  41. cortex/ui/markdown_render.py +185 -0
  42. cortex/ui/terminal_app.py +534 -0
  43. cortex_llm-1.0.0.dist-info/METADATA +275 -0
  44. cortex_llm-1.0.0.dist-info/RECORD +48 -0
  45. cortex_llm-1.0.0.dist-info/WHEEL +5 -0
  46. cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
  47. cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
  48. cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/fine_tuning/trainer.py
@@ -0,0 +1,957 @@
+ """LoRA training implementation using MLX."""
+
+ import logging
+ import time
+ import os
+ import math
+ from pathlib import Path
+ from typing import Optional, Dict, Any, Callable, Tuple
+ from dataclasses import dataclass
+ import json
+ import shutil
+
+ try:
+     import mlx.core as mx
+     import mlx.nn as nn
+     import mlx.optimizers as optim
+     from mlx.utils import tree_map
+     MLX_AVAILABLE = True
+ except Exception as exc:  # noqa: BLE001
+     MLX_AVAILABLE = False
+     mx = nn = optim = tree_map = None  # type: ignore
+     _MLX_IMPORT_ERROR = exc
+
+ # Import MLX LM functions
+ try:
+     from mlx_lm import load as mlx_load
+     from mlx_lm.tuner.lora import LoRALinear
+     from mlx_lm.tuner.trainer import TrainingArgs, train as mlx_train
+     from mlx_lm.tuner.datasets import load_dataset as mlx_load_dataset
+ except ImportError:
+     # Fallback implementations
+     mlx_load = None
+     LoRALinear = None
+     TrainingArgs = None
+     mlx_train = None
+     mlx_load_dataset = None
+
+ from cortex.model_manager import ModelManager
+ from cortex.config import Config
+ from cortex.metal.mlx_accelerator import MLXAccelerator, MLXConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class TrainingConfig:
+     """Enhanced configuration for fine-tuning with intelligent defaults."""
+     # Core training parameters
+     epochs: int = 2
+     learning_rate: float = 3e-5
+     batch_size: int = 1
+     gradient_accumulation_steps: int = 4
+
+     # LoRA parameters
+     lora_r: int = 16  # LoRA rank
+     lora_alpha: int = 32  # LoRA alpha
+     lora_dropout: float = 0.1
+     target_modules: list = None  # Auto-detect if None
+     num_lora_layers: int = 16  # Number of layers to apply LoRA to
+
+     # Optimization parameters
+     optimizer_type: str = "adamw"  # adamw, sgd, adafactor
+     weight_decay: float = 0.01
+     max_grad_norm: float = 1.0
+     warmup_steps: Optional[int] = None  # If None, calculated from warmup_ratio
+     warmup_ratio: float = 0.1
+     lr_scheduler: str = "linear"  # linear, cosine, constant, polynomial
+
+     # Memory and performance
+     gradient_checkpointing: bool = False
+     quantization_bits: Optional[int] = None  # 4 or 8 bit quantization
+     dataloader_num_workers: int = 0
+     fp16: bool = True
+     bf16: bool = False
+
+     # Task-specific settings
+     task_type: str = "chat"  # chat, completion, structured
+     max_sequence_length: int = 2048
+     response_template: Optional[str] = None
+
+     # Dataset settings
+     train_test_split: float = 0.0  # If > 0, split dataset for validation
+     shuffle_dataset: bool = True
+
+     # Advanced settings
+     seed: int = 42
+     logging_steps: int = 10
+     eval_steps: Optional[int] = None
+     save_steps: int = 500
+     early_stopping_patience: Optional[int] = None
+
+     # Model-aware settings (populated automatically)
+     model_size_category: str = "medium"  # tiny, small, medium, large, xlarge
+     estimated_parameters_b: float = 2.0  # Estimated parameters in billions
+     auto_configured: bool = False  # Whether config was auto-generated
+     configuration_source: str = "manual"  # manual, smart_quick, smart_balanced, smart_quality
+
+     def __post_init__(self):
+         if self.target_modules is None:
+             # Default target modules for LoRA
+             self.target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
+
+     def validate(self) -> Tuple[bool, str]:
+         """Validate configuration settings."""
+         if self.learning_rate <= 0 or self.learning_rate > 1:
+             return False, f"Invalid learning rate: {self.learning_rate}"
+
+         if self.epochs < 1 or self.epochs > 100:
+             return False, f"Invalid number of epochs: {self.epochs}"
+
+         if self.batch_size < 1 or self.batch_size > 128:
+             return False, f"Invalid batch size: {self.batch_size}"
+
+         if self.lora_r < 1 or self.lora_r > 256:
+             return False, f"Invalid LoRA rank: {self.lora_r}"
+
+         if self.quantization_bits and self.quantization_bits not in [4, 8]:
+             return False, f"Invalid quantization bits: {self.quantization_bits}. Must be 4 or 8."
+
+         return True, "Configuration is valid"
+
+
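As a usage sketch (not part of the wheel), here is how this dataclass and its validate() method fit together; the import path follows the package layout above and the values are illustrative:

from cortex.fine_tuning.trainer import TrainingConfig

# Defaults fill in everything not passed explicitly.
config = TrainingConfig(epochs=3, learning_rate=1e-4, lora_r=32, lora_alpha=64)

ok, message = config.validate()
if not ok:
    raise ValueError(message)

# __post_init__ supplies the default attention projections when none are given.
print(config.target_modules)  # ['q_proj', 'v_proj', 'k_proj', 'o_proj']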
+ class SmartConfigFactory:
+     """Factory for creating intelligent training configurations based on model and data characteristics."""
+
+     # Model size categories (parameters in billions)
+     MODEL_CATEGORIES = {
+         "tiny": (0, 0.5),             # < 500M parameters (e.g., DistilBERT, small GPT-2)
+         "small": (0.5, 2),            # 500M-2B (e.g., GPT-2, small Llama)
+         "medium": (2, 8),             # 2B-8B (e.g., Gemma-7B, Llama-2-7B)
+         "large": (8, 20),             # 8B-20B (e.g., Llama-2-13B, Mistral-7B variants)
+         "xlarge": (20, float('inf'))  # 20B+ (e.g., Llama-2-70B, GPT-3.5+)
+     }
+
+     # Optimal settings by model size category
+     CATEGORY_DEFAULTS = {
+         "tiny": {
+             "learning_rate": 5e-4,  # Higher LR for small models
+             "epochs": 5,            # More epochs needed
+             "lora_r": 8,            # Lower rank sufficient
+             "lora_alpha": 16,
+             "batch_size": 4,        # Can handle larger batches
+             "gradient_accumulation_steps": 2,
+             "warmup_ratio": 0.05,   # Less warmup needed
+             "weight_decay": 0.001,  # Less regularization
+         },
+         "small": {
+             "learning_rate": 3e-4,
+             "epochs": 4,
+             "lora_r": 16,
+             "lora_alpha": 32,
+             "batch_size": 2,
+             "gradient_accumulation_steps": 4,
+             "warmup_ratio": 0.1,
+             "weight_decay": 0.01,
+         },
+         "medium": {
+             "learning_rate": 1e-4,  # Standard settings for most models
+             "epochs": 3,
+             "lora_r": 16,
+             "lora_alpha": 32,
+             "batch_size": 1,
+             "gradient_accumulation_steps": 8,
+             "warmup_ratio": 0.1,
+             "weight_decay": 0.01,
+         },
+         "large": {
+             "learning_rate": 5e-5,  # Lower LR for stability
+             "epochs": 2,
+             "lora_r": 32,           # Higher rank for complex models
+             "lora_alpha": 64,
+             "batch_size": 1,
+             "gradient_accumulation_steps": 16,
+             "warmup_ratio": 0.15,   # More warmup
+             "weight_decay": 0.01,
+         },
+         "xlarge": {
+             "learning_rate": 2e-5,  # Very conservative
+             "epochs": 2,
+             "lora_r": 64,           # High rank for very large models
+             "lora_alpha": 128,
+             "batch_size": 1,
+             "gradient_accumulation_steps": 32,
+             "warmup_ratio": 0.2,
+             "weight_decay": 0.01,
+         }
+     }
+
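One pattern worth noting in the table above: every size bucket keeps lora_alpha at exactly twice lora_r, so the effective LoRA scaling (alpha/r) stays constant at 2 while the rank, and therefore adapter capacity, grows with model size. A quick check (illustrative, not part of the package):

from cortex.fine_tuning.trainer import SmartConfigFactory

for name, row in SmartConfigFactory.CATEGORY_DEFAULTS.items():
    # Only rank, learning rate, and accumulation change with model size;
    # the alpha/r ratio is held at 2.0 across all presets.
    assert row["lora_alpha"] / row["lora_r"] == 2.0
    print(name, row["lora_r"], row["learning_rate"])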
+     @classmethod
+     def categorize_model_size(cls, size_gb: float, model_manager=None, model_path=None) -> Tuple[str, float]:
+         """Categorize model based on actual parameters if possible, fallback to size estimation."""
+         estimated_params_b = size_gb / 2.0  # Fallback estimation
+
+         # Try to get accurate parameter count if model_manager and path are provided
+         if model_manager and model_path:
+             try:
+                 from pathlib import Path
+                 actual_params_b = model_manager.get_model_parameters_smart(Path(model_path))
+                 if actual_params_b is not None:
+                     estimated_params_b = actual_params_b  # Already in billions
+                     logger.info(f"Using accurate parameter count: {estimated_params_b:.2f}B parameters")
+                 else:
+                     logger.warning(f"Could not detect parameters, using size estimation: {estimated_params_b:.2f}B")
+             except Exception as e:
+                 logger.warning(f"Parameter detection failed: {e}, using size estimation")
+
+         for category, (min_params, max_params) in cls.MODEL_CATEGORIES.items():
+             if min_params <= estimated_params_b < max_params:
+                 return category, estimated_params_b
+
+         # Fallback to medium if can't categorize
+         return "medium", estimated_params_b
+
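A worked example of the fallback path (no model_manager supplied): dividing the on-disk size by 2.0 treats a checkpoint as roughly two bytes per parameter, so a 13 GB model is estimated at 6.5B parameters and lands in the "medium" bucket:

from cortex.fine_tuning.trainer import SmartConfigFactory

category, params_b = SmartConfigFactory.categorize_model_size(13.0)
print(category, params_b)  # "medium", 6.5 (6.5B falls in the 2-8B range)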
+     @classmethod
+     def analyze_dataset(cls, dataset_path: Path) -> Dict[str, Any]:
+         """Analyze dataset to inform training configuration."""
+         try:
+             examples = []
+             with open(dataset_path, 'r') as f:
+                 for line in f:
+                     examples.append(json.loads(line.strip()))
+
+             dataset_size = len(examples)
+
+             # Analyze content to detect task type
+             task_type = "chat"  # Default
+             avg_length = 0
+
+             if examples:
+                 sample = examples[0]
+
+                 # Detect task type from structure
+                 if 'prompt' in sample and 'response' in sample:
+                     task_type = "chat"
+                 elif 'prompt' in sample and 'completion' in sample:
+                     task_type = "completion"
+                 elif 'text' in sample:
+                     task_type = "completion"
+
+                 # Calculate average text length
+                 total_chars = 0
+                 for example in examples[:100]:  # Sample first 100
+                     text = ""
+                     if 'text' in example:
+                         text = example['text']
+                     elif 'prompt' in example and 'response' in example:
+                         text = example['prompt'] + example['response']
+                     elif 'prompt' in example and 'completion' in example:
+                         text = example['prompt'] + example['completion']
+                     total_chars += len(text)
+
+                 avg_length = total_chars // min(len(examples), 100)
+
+             return {
+                 "size": dataset_size,
+                 "task_type": task_type,
+                 "avg_length": avg_length,
+                 "size_category": cls._get_dataset_size_category(dataset_size)
+             }
+         except Exception as e:
+             logger.warning(f"Failed to analyze dataset: {e}")
+             return {
+                 "size": 0,
+                 "task_type": "chat",
+                 "avg_length": 1000,
+                 "size_category": "small"
+             }
+
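A sketch of the JSONL shape this method expects and the summary it returns; the file path and record are illustrative:

import json
from pathlib import Path
from cortex.fine_tuning.trainer import SmartConfigFactory

# One JSON object per line; prompt/response pairs are detected as the "chat" task type.
sample = {"prompt": "What is LoRA?", "response": "A low-rank adaptation method for fine-tuning."}
path = Path("train.jsonl")
path.write_text(json.dumps(sample) + "\n")

info = SmartConfigFactory.analyze_dataset(path)
print(info)  # e.g. {'size': 1, 'task_type': 'chat', 'avg_length': ..., 'size_category': 'tiny'}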
+     @classmethod
+     def _get_dataset_size_category(cls, size: int) -> str:
+         """Categorize dataset by size."""
+         if size < 50:
+             return "tiny"
+         elif size < 500:
+             return "small"
+         elif size < 2000:
+             return "medium"
+         elif size < 10000:
+             return "large"
+         else:
+             return "xlarge"
+
+     @classmethod
+     def create_smart_config(
+         cls,
+         model_size_gb: float,
+         dataset_path: Path,
+         preset: str = "balanced",
+         custom_settings: Optional[Dict[str, Any]] = None,
+         model_manager=None,
+         model_path: Optional[str] = None
+     ) -> TrainingConfig:
+         """Create an intelligent training configuration."""
+
+         # Analyze model with accurate parameter detection
+         model_category, estimated_params = cls.categorize_model_size(
+             model_size_gb, model_manager, model_path
+         )
+
+         # Analyze dataset
+         dataset_info = cls.analyze_dataset(dataset_path)
+
+         # Get base settings for model category
+         base_config = cls.CATEGORY_DEFAULTS[model_category].copy()
+
+         # Apply preset modifications
+         if preset == "quick":
+             base_config["epochs"] = max(1, base_config["epochs"] - 1)
+             base_config["learning_rate"] *= 1.5  # Faster learning
+         elif preset == "quality":
+             base_config["epochs"] += 1
+             base_config["learning_rate"] *= 0.8  # More conservative
+             base_config["lora_r"] = min(64, base_config["lora_r"] * 2)  # Higher rank
+
+         # Adjust for dataset size
+         dataset_size = dataset_info["size"]
+         if dataset_size < 100:  # Small dataset
+             base_config["epochs"] = min(base_config["epochs"] + 2, 8)  # More epochs
+             base_config["weight_decay"] *= 0.5  # Less regularization
+         elif dataset_size > 5000:  # Large dataset
+             base_config["epochs"] = max(1, base_config["epochs"] - 1)  # Fewer epochs
+
+         # Adjust for sequence length
+         if dataset_info["avg_length"] > 2000:
+             base_config["gradient_accumulation_steps"] *= 2  # Handle memory
+             base_config["max_sequence_length"] = 4096
+
+         total_mem_gb = cls._get_total_memory_gb()
+         memory_guard_applied = cls._apply_memory_guards(base_config, total_mem_gb)
+
+         # Apply custom settings if provided
+         if custom_settings:
+             base_config.update(custom_settings)
+
+         # Create configuration
+         config = TrainingConfig(
+             # Core parameters
+             epochs=base_config["epochs"],
+             learning_rate=base_config["learning_rate"],
+             batch_size=base_config["batch_size"],
+             gradient_accumulation_steps=base_config["gradient_accumulation_steps"],
+
+             # LoRA parameters
+             lora_r=base_config["lora_r"],
+             lora_alpha=base_config["lora_alpha"],
+
+             # Optimization
+             weight_decay=base_config["weight_decay"],
+             warmup_ratio=base_config["warmup_ratio"],
+
+             # Task-specific
+             task_type=dataset_info["task_type"],
+             max_sequence_length=base_config.get("max_sequence_length", 2048),
+
+             # Metadata
+             model_size_category=model_category,
+             estimated_parameters_b=estimated_params,
+             auto_configured=True,
+             configuration_source=f"smart_{preset}{'_memory_guarded' if memory_guard_applied else ''}"
+         )
+
+         return config
+
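Putting the factory together, a hedged sketch of requesting an auto-generated configuration; the size, paths, and overrides are illustrative, and model_manager is optional:

from pathlib import Path
from cortex.fine_tuning.trainer import SmartConfigFactory

config = SmartConfigFactory.create_smart_config(
    model_size_gb=13.0,                  # ~6.5B params -> "medium" defaults
    dataset_path=Path("train.jsonl"),    # illustrative path
    preset="quality",                    # +1 epoch, 0.8x LR, up to 2x LoRA rank
    custom_settings={"batch_size": 2},   # applied after the heuristics and memory guards
)
print(config.configuration_source)       # "smart_quality" (plus "_memory_guarded" if capped)
print(config.epochs, config.learning_rate, config.lora_r)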
+     @classmethod
+     def get_preset_configs(cls) -> Dict[str, Dict[str, Any]]:
+         """Get preset configuration descriptions."""
+         return {
+             "quick": {
+                 "name": "Quick",
+                 "description": "Fast training with fewer epochs",
+                 "use_case": "Quick experimentation and testing",
+                 "time_factor": 0.7
+             },
+             "balanced": {
+                 "name": "Balanced",
+                 "description": "Optimal balance of speed and quality",
+                 "use_case": "Most general use cases (recommended)",
+                 "time_factor": 1.0
+             },
+             "quality": {
+                 "name": "Quality",
+                 "description": "Best results with more training",
+                 "use_case": "Production models, important tasks",
+                 "time_factor": 1.5
+             }
+         }
+
+     @classmethod
+     def generate_guidance_message(cls, config: TrainingConfig, model_name: str) -> str:
+         """Generate helpful guidance message for the user."""
+         messages = []
+         if config.configuration_source.endswith("memory_guarded"):
+             messages.append("Applied memory guard for this machine: capped batch/seq/accum to avoid GPU/UM pressure")
+
+         # Model-specific guidance
+         if config.model_size_category == "tiny":
+             messages.append(f"Detected tiny model ({config.estimated_parameters_b:.1f}B params) - using higher learning rate for better convergence")
+         elif config.model_size_category == "small":
+             messages.append(f"Detected small model ({config.estimated_parameters_b:.1f}B params) - using optimized settings")
+         elif config.model_size_category == "large":
+             messages.append(f"Detected large model ({config.estimated_parameters_b:.1f}B params) - using careful settings for stability")
+         elif config.model_size_category == "xlarge":
+             messages.append(f"Detected very large model ({config.estimated_parameters_b:.1f}B params) - using conservative settings for stability")
+
+         # Learning rate guidance
+         if config.learning_rate > 1e-4:
+             messages.append(f"Using accelerated learning rate ({config.learning_rate:.1e}) - suitable for smaller models")
+         elif config.learning_rate < 5e-5:
+             messages.append(f"Using conservative learning rate ({config.learning_rate:.1e}) - prevents overfitting in large models")
+
+         # LoRA guidance
+         if config.lora_r >= 32:
+             messages.append(f"Using high LoRA rank ({config.lora_r}) - captures more model complexity")
+         elif config.lora_r <= 8:
+             messages.append(f"Using low LoRA rank ({config.lora_r}) - efficient for simpler adaptations")
+
+         # Epoch guidance
+         if config.epochs >= 5:
+             messages.append(f"Training for {config.epochs} epochs - extra iterations for small datasets")
+         elif config.epochs == 1:
+             messages.append("Single epoch training - suitable for large datasets")
+
+         if not messages:
+             messages.append(f"Using optimized settings for {config.model_size_category} model")
+
+         return "\n ".join(messages)
+
+     @staticmethod
+     def _get_total_memory_gb() -> Optional[float]:
+         """Approximate total unified memory on macOS (used as GPU-visible memory)."""
+         try:
+             page_size = os.sysconf("SC_PAGE_SIZE")
+             phys_pages = os.sysconf("SC_PHYS_PAGES")
+             total_bytes = page_size * phys_pages
+             return round(total_bytes / (1024**3), 1)
+         except Exception as exc:  # noqa: BLE001
+             logger.debug(f"Total memory detection failed: {exc}")
+             return None
+
+     @classmethod
+     def _apply_memory_guards(cls, cfg: Dict[str, Any], total_mem_gb: Optional[float]) -> bool:
+         """
+         Downscale aggressive settings on lower-memory Apple Silicon to reduce GPU/UM hangs.
+
+         Heuristics:
+         - <=16GB: cap seq length to 1024, batch=1, grad_acc<=2
+         - <=32GB: cap seq length to 2048, batch<=2, grad_acc<=4
+         - Additionally cap effective tokens (batch*grad_acc*max_seq) to avoid runaway memory.
+         """
+         if not total_mem_gb:
+             return False
+
+         guard_applied = False
+         effective_tokens = lambda c: c["batch_size"] * c["gradient_accumulation_steps"] * c.get("max_sequence_length", 2048)
+
+         if total_mem_gb <= 16:
+             if cfg["batch_size"] > 1:
+                 cfg["batch_size"] = 1
+                 guard_applied = True
+             if cfg["gradient_accumulation_steps"] > 2:
+                 cfg["gradient_accumulation_steps"] = 2
+                 guard_applied = True
+             max_seq = cfg.get("max_sequence_length", 2048)
+             if max_seq > 1024:
+                 cfg["max_sequence_length"] = 1024
+                 guard_applied = True
+             target_tokens = 4096
+         elif total_mem_gb <= 32:
+             if cfg["batch_size"] > 2:
+                 cfg["batch_size"] = 2
+                 guard_applied = True
+             if cfg["gradient_accumulation_steps"] > 4:
+                 cfg["gradient_accumulation_steps"] = 4
+                 guard_applied = True
+             max_seq = cfg.get("max_sequence_length", 2048)
+             if max_seq > 2048:
+                 cfg["max_sequence_length"] = 2048
+                 guard_applied = True
+             target_tokens = 8192
+         else:
+             target_tokens = 12288  # Leave roomy settings for higher-memory hosts
+
+         # Gradient checkpointing trades compute for memory; enable when guarding.
+         if guard_applied and not cfg.get("gradient_checkpointing", False):
+             cfg["gradient_checkpointing"] = True
+
+         # If the overall token budget is still too high, scale down grad_acc first, then seq length.
+         curr_tokens = effective_tokens(cfg)
+         if curr_tokens > target_tokens:
+             scale = max(1, math.ceil(curr_tokens / target_tokens))
+             new_grad_acc = max(1, cfg["gradient_accumulation_steps"] // scale)
+             if new_grad_acc < cfg["gradient_accumulation_steps"]:
+                 cfg["gradient_accumulation_steps"] = new_grad_acc
+                 guard_applied = True
+             curr_tokens = effective_tokens(cfg)
+             if curr_tokens > target_tokens:
+                 new_seq = max(256, cfg.get("max_sequence_length", 2048) // scale)
+                 if new_seq < cfg.get("max_sequence_length", 2048):
+                     cfg["max_sequence_length"] = new_seq
+                     guard_applied = True
+
+         if guard_applied:
+             logger.info(
+                 f"Memory guard applied (total_mem={total_mem_gb}GB): "
+                 f"batch={cfg['batch_size']}, grad_acc={cfg['gradient_accumulation_steps']}, "
+                 f"max_seq={cfg.get('max_sequence_length', 2048)}"
+             )
+         return guard_applied
+
+
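To make the guard concrete: starting from the "small" defaults (batch 2, grad_acc 4, implicit seq 2048) on a 16 GB host, batch drops to 1, grad_acc to 2, max_sequence_length to 1024, gradient checkpointing is switched on, and the resulting 1 x 2 x 1024 = 2048 effective tokens already sit under the 4096 target, so no further scaling is applied. An illustrative call to the (private) helper:

from cortex.fine_tuning.trainer import SmartConfigFactory

cfg = dict(SmartConfigFactory.CATEGORY_DEFAULTS["small"])  # batch 2, grad_acc 4
guarded = SmartConfigFactory._apply_memory_guards(cfg, total_mem_gb=16.0)

print(guarded)                              # True
print(cfg["batch_size"],                    # 1
      cfg["gradient_accumulation_steps"],   # 2
      cfg["max_sequence_length"],           # 1024
      cfg["gradient_checkpointing"])        # True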
+ class LoRATrainer:
+     """Trainer for LoRA fine-tuning using MLX."""
+
+     def __init__(self, model_manager: ModelManager, config: Config):
+         """Initialize the trainer."""
+         self.model_manager = model_manager
+         self.config = config
+         self.mlx_accelerator = MLXAccelerator(MLXConfig())
+
+     def train(
+         self,
+         base_model_name: str,
+         dataset_path: Path,
+         output_name: str,
+         config: TrainingConfig,
+         progress_callback: Optional[Callable] = None
+     ) -> bool:
+         """
+         Train a model using LoRA.
+
+         Args:
+             base_model_name: Name of the base model to fine-tune
+             dataset_path: Path to the training dataset
+             output_name: Name for the fine-tuned model
+             config: Training configuration
+             progress_callback: Optional callback for progress updates
+
+         Returns:
+             True if training succeeded, False otherwise
+         """
+         try:
+             if not MLX_AVAILABLE:
+                 logger.error("MLX is not available; fine-tuning requires MLX.")
+                 if "_MLX_IMPORT_ERROR" in globals():
+                     logger.debug(f"MLX import error: {_MLX_IMPORT_ERROR}")  # type: ignore[name-defined]
+                 return False
+             logger.info(f"Starting LoRA training: {base_model_name} -> {output_name}")
+
+             # Step 1: Load base model
+             logger.info("Loading base model...")
+             model, tokenizer = self._load_base_model(base_model_name)
+             if model is None:
+                 logger.error("Failed to load base model")
+                 return False
+
+             # Step 2: Apply LoRA layers
+             logger.info(f"Applying LoRA with rank={config.lora_r}")
+             model = self._apply_lora(model, config)
+
+             # Step 3: Load and prepare dataset
+             logger.info("Loading dataset...")
+             train_dataset = self._load_dataset(dataset_path, tokenizer, config)
+             if train_dataset is None:
+                 logger.error("Failed to load dataset")
+                 return False
+
+             # Step 4: Setup optimizer
+             optimizer = self._setup_optimizer(model, config)
+
+             # Step 5: Training loop
+             logger.info(f"Starting training for {config.epochs} epochs...")
+             trained_model = self._training_loop(
+                 model=model,
+                 dataset=train_dataset,
+                 optimizer=optimizer,
+                 config=config,
+                 tokenizer=tokenizer,
+                 progress_callback=progress_callback
+             )
+
+             # Step 6: Save fine-tuned model
+             logger.info(f"Saving fine-tuned model as {output_name}...")
+             success = self._save_model(trained_model, tokenizer, output_name, base_model_name)
+
+             if success:
+                 logger.info(f"Successfully fine-tuned model saved as {output_name}")
+                 return True
+             else:
+                 logger.error("Failed to save fine-tuned model")
+                 return False
+
+         except Exception as e:
+             logger.error(f"Training failed: {e}")
+             return False
+
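An end-to-end usage sketch; it assumes Config() and ModelManager(config) can be constructed this way and that the base model is already loaded in the manager's cache, which is what the fine-tuning wizard normally guarantees:

from pathlib import Path
from cortex.config import Config
from cortex.model_manager import ModelManager
from cortex.fine_tuning.trainer import LoRATrainer, TrainingConfig

# Assumption: these constructors accept no extra required arguments.
config = Config()
manager = ModelManager(config)

def on_progress(epoch, step, loss):
    # Matches the (epoch, step, avg_loss) callback signature used by the training loop.
    print(f"epoch={epoch} step={step} loss={loss:.4f}")

trainer = LoRATrainer(manager, config)
ok = trainer.train(
    base_model_name="my-base-model",   # illustrative name
    dataset_path=Path("train.jsonl"),
    output_name="my-base-model-ft",
    config=TrainingConfig(epochs=2, lora_r=16),
    progress_callback=on_progress,
)
print("training succeeded" if ok else "training failed")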
+     def _load_base_model(self, model_name: str) -> Tuple[Optional[Any], Optional[Any]]:
+         """Load the base model and tokenizer."""
+         try:
+             # The model should already be loaded by the ModelManager
+             # We just need to get it from the cache
+
+             # Try all possible cache keys
+             possible_keys = [
+                 model_name,
+                 self.model_manager.current_model,
+                 # Sometimes the model is stored with path as key
+                 str(Path.home() / ".cortex" / "mlx_models" / model_name),
+             ]
+
+             model = None
+             tokenizer = None
+
+             for key in possible_keys:
+                 if key and not model:
+                     model = self.model_manager.model_cache.get(key)
+                 if key and not tokenizer:
+                     tokenizer = self.model_manager.tokenizers.get(key)
+
+                 if model and tokenizer:
+                     logger.info(f"Using loaded model from cache (key: {key})")
+                     break
+
+             if model and tokenizer:
+                 return model, tokenizer
+
+             # If not in cache, this is unexpected since the wizard confirmed the model is loaded
+             logger.error(f"Model {model_name} not found in cache. Available keys: {list(self.model_manager.model_cache.keys())}")
+             logger.error(f"Current model: {self.model_manager.current_model}")
+
+             # As a fallback, try to load it (but this shouldn't happen)
+             logger.warning(f"Attempting to reload model {model_name}")
+
+             # First check if it's already an MLX model to avoid re-conversion
+             mlx_path = Path.home() / ".cortex" / "mlx_models" / model_name
+             if mlx_path.exists():
+                 # It's already converted, load it directly
+                 success, message = self.model_manager.load_model(str(mlx_path), model_name=model_name)
+             else:
+                 # Try loading from original location
+                 success, message = self.model_manager.load_model(model_name)
+
+             if not success:
+                 logger.error(f"Failed to load model: {message}")
+                 return None, None
+
+             # Try to get it from cache again
+             model = self.model_manager.model_cache.get(model_name) or self.model_manager.model_cache.get(self.model_manager.current_model)
+             tokenizer = self.model_manager.tokenizers.get(model_name) or self.model_manager.tokenizers.get(self.model_manager.current_model)
+
+             if not model or not tokenizer:
+                 logger.error(f"Model or tokenizer still not available after reload")
+                 return None, None
+
+             return model, tokenizer
+
+         except Exception as e:
+             logger.error(f"Error loading base model: {e}")
+             return None, None
+
+     def _apply_lora(self, model: Any, config: TrainingConfig) -> Any:
+         """Apply LoRA layers to the model."""
+         if LoRALinear is None:
+             # Fallback: Simple LoRA implementation
+             logger.warning("mlx_lm LoRA not available, using basic implementation")
+             return self._apply_basic_lora(model, config)
+
+         # Use mlx_lm's LoRA implementation
+         lora_layers = 0
+
+         def apply_lora_to_linear(layer):
+             nonlocal lora_layers
+             if isinstance(layer, nn.Linear):
+                 # Check if this is a target module
+                 for target in config.target_modules:
+                     if hasattr(layer, '__name__') and target in str(layer.__name__):
+                         # Replace with LoRA layer
+                         lora_layers += 1
+                         return LoRALinear(
+                             in_features=layer.weight.shape[1],
+                             out_features=layer.weight.shape[0],
+                             r=config.lora_r,
+                             alpha=config.lora_alpha,
+                             dropout=config.lora_dropout
+                         )
+                 return layer
+             return layer
+
+         # Apply LoRA to all linear layers in target modules
+         model = tree_map(apply_lora_to_linear, model)
+         logger.info(f"Applied LoRA to {lora_layers} layers")
+
+         return model
+
+     def _apply_basic_lora(self, model: Any, config: TrainingConfig) -> Any:
+         """Apply basic LoRA implementation."""
+         class BasicLoRALinear(nn.Module):
+             def __init__(self, linear_layer, r=16, alpha=32):
+                 super().__init__()
+                 self.linear = linear_layer
+                 self.r = r
+                 self.alpha = alpha
+
+                 # LoRA parameters
+                 in_features = linear_layer.weight.shape[1]
+                 out_features = linear_layer.weight.shape[0]
+
+                 # Low-rank matrices
+                 self.lora_a = mx.random.normal((r, in_features)) * 0.01
+                 self.lora_b = mx.zeros((out_features, r))
+
+                 # Scaling factor
+                 self.scaling = alpha / r
+
+             def __call__(self, x):
+                 # Original forward pass
+                 result = self.linear(x)
+
+                 # Add LoRA contribution
+                 lora_out = x @ self.lora_a.T @ self.lora_b.T * self.scaling
+
+                 return result + lora_out
+
+         # Apply to target modules
+         def apply_basic_lora_to_layer(layer):
+             if isinstance(layer, nn.Linear):
+                 return BasicLoRALinear(layer, r=config.lora_r, alpha=config.lora_alpha)
+             return layer
+
+         model = tree_map(apply_basic_lora_to_layer, model)
+         return model
+
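The fallback layer implements the standard LoRA update y = W x + (alpha/r) * (x @ A.T @ B.T), with A of shape (r, in_features) and B of shape (out_features, r); because B starts at zero, the adapted layer initially behaves exactly like the base layer. A small shape and no-op check, assuming MLX is installed:

import mlx.core as mx

batch, in_features, out_features, r, alpha = 2, 8, 4, 2, 4

W = mx.random.normal((out_features, in_features))
A = mx.random.normal((r, in_features)) * 0.01   # lora_a
B = mx.zeros((out_features, r))                 # lora_b (zero-initialised)

x = mx.random.normal((batch, in_features))
base = x @ W.T
lora = x @ A.T @ B.T * (alpha / r)

print(lora.shape)                      # (2, 4): same shape as the base output
print(mx.allclose(base + lora, base))  # True: with B = 0 the adapter is a no-op at start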
+     def _load_dataset(self, dataset_path: Path, tokenizer: Any, config: TrainingConfig) -> Optional[Any]:
+         """Load and prepare the dataset."""
+         try:
+             # Load JSONL dataset
+             examples = []
+             with open(dataset_path, 'r') as f:
+                 for line in f:
+                     data = json.loads(line.strip())
+                     examples.append(data)
+
+             # Tokenize examples
+             tokenized_examples = []
+             max_seq_len = getattr(config, "max_sequence_length", None)
+             for example in examples:
+                 # Format as conversation
+                 if 'prompt' in example and 'response' in example:
+                     text = f"User: {example['prompt']}\nAssistant: {example['response']}"
+                 elif 'text' in example:
+                     text = example['text']
+                 else:
+                     continue
+
+                 # Tokenize
+                 tokens = tokenizer.encode(text)
+                 if max_seq_len and len(tokens) > max_seq_len:
+                     tokens = tokens[:max_seq_len]
+                 tokenized_examples.append({
+                     'input_ids': mx.array(tokens),
+                     'labels': mx.array(tokens)  # For causal LM
+                 })
+
+             logger.info(f"Loaded {len(tokenized_examples)} training examples")
+             return tokenized_examples
+
+         except Exception as e:
+             logger.error(f"Error loading dataset: {e}")
+             return None
+
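For clarity, this is the record-to-text mapping applied to prompt/response pairs before tokenization (plain string handling, nothing package-specific):

example = {"prompt": "Summarise LoRA in one line.",
           "response": "It adds trainable low-rank matrices to frozen weights."}

# Mirrors the formatting branch above for prompt/response records.
text = f"User: {example['prompt']}\nAssistant: {example['response']}"
print(text)
# User: Summarise LoRA in one line.
# Assistant: It adds trainable low-rank matrices to frozen weights.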
+     def _setup_optimizer(self, model: Any, config: TrainingConfig) -> Any:
+         """Setup the optimizer."""
+         # Get trainable parameters (LoRA parameters only)
+         trainable_params = []
+
+         def get_lora_params(module, prefix=""):
+             # Check for LoRA parameters
+             if hasattr(module, 'lora_a'):
+                 trainable_params.append(module.lora_a)
+             if hasattr(module, 'lora_b'):
+                 trainable_params.append(module.lora_b)
+
+             # Try to iterate over child modules
+             try:
+                 # Try vars() first (for regular Python objects)
+                 children = vars(module).items()
+             except TypeError:
+                 # If vars() doesn't work, try __dict__ directly
+                 if hasattr(module, '__dict__'):
+                     children = module.__dict__.items()
+                 else:
+                     # For MLX modules, try to get children differently
+                     children = []
+                     if hasattr(module, 'children'):
+                         for child in module.children():
+                             children.append(('', child))
+
+             for name, child in children:
+                 if isinstance(child, nn.Module):
+                     get_lora_params(child, f"{prefix}.{name}")
+
+         # Only try to extract LoRA params if model is a Module
+         if isinstance(model, nn.Module):
+             get_lora_params(model)
+
+         if not trainable_params:
+             # If no LoRA params found, train all parameters (fallback)
+             logger.warning("No LoRA parameters found, training all parameters")
+             # For MLX models, we need to get parameters differently
+             if hasattr(model, 'parameters'):
+                 trainable_params = list(model.parameters())
+             else:
+                 logger.error("Model has no parameters() method")
+                 trainable_params = []
+
+         # Create optimizer
+         optimizer = optim.AdamW(
+             learning_rate=config.learning_rate,
+             weight_decay=config.weight_decay
+         )
+
+         # Initialize optimizer state
+         optimizer.init(trainable_params)
+
+         logger.info(f"Initialized optimizer with {len(trainable_params)} trainable parameters")
+         return optimizer
+
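For comparison, the idiomatic MLX route to "train only the adapter weights" is to freeze the module and re-enable just the LoRA matrices, then read trainable_parameters(); a minimal sketch under that assumption (generic MLX APIs, not code from this package):

import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten

def count_trainable(model: nn.Module) -> int:
    # tree_flatten turns the nested parameter dict into (name, array) pairs.
    return sum(p.size for _, p in tree_flatten(model.trainable_parameters()))

def make_optimizer(model: nn.Module, learning_rate: float, weight_decay: float):
    model.freeze()  # freeze every base weight...
    # ...then re-enable only parameters literally named lora_a / lora_b
    # (assumes the adapter parameters use those names, as in the fallback above).
    model.unfreeze(keys=["lora_a", "lora_b"])
    print(f"trainable parameters: {count_trainable(model):,}")
    return optim.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)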
+     def _training_loop(
+         self,
+         model: Any,
+         dataset: list,
+         optimizer: Any,
+         config: TrainingConfig,
+         tokenizer: Any,
+         progress_callback: Optional[Callable] = None
+     ) -> Any:
+         """Main training loop."""
+         model.train()
+
+         total_steps = len(dataset) * config.epochs
+         current_step = 0
+
+         for epoch in range(config.epochs):
+             epoch_loss = 0.0
+             batch_loss = 0.0
+
+             for i, batch in enumerate(dataset):
+                 # Forward pass
+                 input_ids = batch['input_ids']
+                 labels = batch['labels']
+
+                 # Compute loss
+                 logits = model(input_ids[None, :])  # Add batch dimension
+
+                 # Cross-entropy loss
+                 loss = mx.mean(
+                     nn.losses.cross_entropy(
+                         logits[0, :-1],  # All but last prediction
+                         labels[1:],      # All but first token
+                         reduction='none'
+                     )
+                 )
+
+                 # Backward pass
+                 loss_value, grads = mx.value_and_grad(lambda m: loss)(model)
+
+                 # Gradient accumulation
+                 batch_loss += loss_value.item()
+
+                 if (i + 1) % config.gradient_accumulation_steps == 0:
+                     # Update weights
+                     optimizer.update(model, grads)
+
+                     # Clear accumulated loss
+                     avg_loss = batch_loss / config.gradient_accumulation_steps
+                     epoch_loss += avg_loss
+                     batch_loss = 0.0
+
+                     # Progress callback
+                     if progress_callback:
+                         progress_callback(epoch, i, avg_loss)
+
+                 current_step += 1
+
+                 # Evaluate to ensure computation
+                 mx.eval(model.parameters())
+
+             # Log epoch statistics
+             avg_epoch_loss = epoch_loss / (len(dataset) / config.gradient_accumulation_steps)
+             logger.info(f"Epoch {epoch+1}/{config.epochs} - Loss: {avg_epoch_loss:.4f}")
+
+         return model
+
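As a reference point, MLX computes gradients by transforming a loss function rather than differentiating an already-computed loss value; a minimal sketch of one accumulation-free training step using nn.value_and_grad (a generic MLX pattern, not code from this package):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

def train_step(model: nn.Module, optimizer: optim.Optimizer, input_ids: mx.array, labels: mx.array) -> float:
    def loss_fn(m):
        logits = m(input_ids[None, :])  # add a batch dimension
        return mx.mean(nn.losses.cross_entropy(logits[0, :-1], labels[1:], reduction="none"))

    # nn.value_and_grad differentiates loss_fn with respect to the model's trainable parameters.
    loss, grads = nn.value_and_grad(model, loss_fn)(model)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)  # force evaluation of the lazy graph
    return loss.item()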
+     def _save_model(
+         self,
+         model: Any,
+         tokenizer: Any,
+         output_name: str,
+         base_model_name: str
+     ) -> bool:
+         """Save the fine-tuned model."""
+         try:
+             # Create output directory in MLX models folder for consistency
+             output_dir = Path.home() / ".cortex" / "mlx_models" / output_name
+             output_dir.mkdir(parents=True, exist_ok=True)
+
+             # Save model weights
+             weights_path = output_dir / "model.safetensors"
+
+             # Get model state dict
+             state_dict = {}
+
+             def extract_weights(module, prefix=""):
+                 for name, param in vars(module).items():
+                     if isinstance(param, mx.array):
+                         state_dict[f"{prefix}.{name}"] = param
+                     elif isinstance(param, nn.Module):
+                         extract_weights(param, f"{prefix}.{name}")
+
+             extract_weights(model)
+
+             # Save using safetensors format (or numpy for simplicity)
+             import numpy as np
+             np_state_dict = {k: v.tolist() for k, v in state_dict.items()}
+
+             with open(weights_path, 'w') as f:
+                 json.dump(np_state_dict, f)
+
+             # Save tokenizer
+             if hasattr(tokenizer, 'save_pretrained'):
+                 tokenizer.save_pretrained(output_dir)
+
+             # Save config
+             config_data = {
+                 "base_model": base_model_name,
+                 "model_type": "fine-tuned",
+                 "fine_tuning_method": "LoRA",
+                 "created_at": time.strftime("%Y-%m-%d %H:%M:%S")
+             }
+
+             with open(output_dir / "config.json", 'w') as f:
+                 json.dump(config_data, f, indent=2)
+
+             # Copy any additional files from base model
+             base_model_path = Path.home() / ".cortex" / "models" / base_model_name
+             if base_model_path.exists():
+                 for file in ['tokenizer_config.json', 'special_tokens_map.json', 'vocab.json']:
+                     src = base_model_path / file
+                     if src.exists():
+                         shutil.copy2(src, output_dir / file)
+
+             logger.info(f"Model saved to {output_dir}")
+             return True
+
+         except Exception as e:
+             logger.error(f"Error saving model: {e}")
+             return False
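
Note that the weights file above is JSON despite its .safetensors name. When MLX is available, the collected arrays could instead be written in the real safetensors format; a hedged sketch, assuming state_dict maps names to mx.array values as gathered by extract_weights:

import mlx.core as mx

def save_weights_safetensors(path: str, state_dict: dict) -> None:
    # mx.save_safetensors expects a plain dict of name -> mx.array.
    mx.save_safetensors(path, state_dict)

def load_weights_safetensors(path: str) -> dict:
    return mx.load(path)  # returns the same name -> array mapping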