cortex-llm 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. cortex/__init__.py +73 -0
  2. cortex/__main__.py +83 -0
  3. cortex/config.py +329 -0
  4. cortex/conversation_manager.py +468 -0
  5. cortex/fine_tuning/__init__.py +8 -0
  6. cortex/fine_tuning/dataset.py +332 -0
  7. cortex/fine_tuning/mlx_lora_trainer.py +502 -0
  8. cortex/fine_tuning/trainer.py +957 -0
  9. cortex/fine_tuning/wizard.py +707 -0
  10. cortex/gpu_validator.py +467 -0
  11. cortex/inference_engine.py +727 -0
  12. cortex/metal/__init__.py +275 -0
  13. cortex/metal/gpu_validator.py +177 -0
  14. cortex/metal/memory_pool.py +886 -0
  15. cortex/metal/mlx_accelerator.py +678 -0
  16. cortex/metal/mlx_converter.py +638 -0
  17. cortex/metal/mps_optimizer.py +417 -0
  18. cortex/metal/optimizer.py +665 -0
  19. cortex/metal/performance_profiler.py +364 -0
  20. cortex/model_downloader.py +130 -0
  21. cortex/model_manager.py +2187 -0
  22. cortex/quantization/__init__.py +5 -0
  23. cortex/quantization/dynamic_quantizer.py +736 -0
  24. cortex/template_registry/__init__.py +15 -0
  25. cortex/template_registry/auto_detector.py +144 -0
  26. cortex/template_registry/config_manager.py +234 -0
  27. cortex/template_registry/interactive.py +260 -0
  28. cortex/template_registry/registry.py +347 -0
  29. cortex/template_registry/template_profiles/__init__.py +5 -0
  30. cortex/template_registry/template_profiles/base.py +142 -0
  31. cortex/template_registry/template_profiles/complex/__init__.py +5 -0
  32. cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
  33. cortex/template_registry/template_profiles/standard/__init__.py +9 -0
  34. cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
  35. cortex/template_registry/template_profiles/standard/chatml.py +82 -0
  36. cortex/template_registry/template_profiles/standard/gemma.py +103 -0
  37. cortex/template_registry/template_profiles/standard/llama.py +87 -0
  38. cortex/template_registry/template_profiles/standard/simple.py +65 -0
  39. cortex/ui/__init__.py +120 -0
  40. cortex/ui/cli.py +1685 -0
  41. cortex/ui/markdown_render.py +185 -0
  42. cortex/ui/terminal_app.py +534 -0
  43. cortex_llm-1.0.0.dist-info/METADATA +275 -0
  44. cortex_llm-1.0.0.dist-info/RECORD +48 -0
  45. cortex_llm-1.0.0.dist-info/WHEEL +5 -0
  46. cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
  47. cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
  48. cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,502 @@
1
+ """MLX LoRA trainer using mlx_lm implementation."""
2
+
3
+ import logging
4
+ import json
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Optional, Dict, Any, Callable
8
+ from dataclasses import dataclass
9
+ import shutil
10
+ import math
11
+
12
+ try:
13
+ import mlx.core as mx
14
+ import mlx.nn as nn
15
+ from mlx_lm import load as mlx_load
16
+ from mlx_lm.tuner.utils import linear_to_lora_layers
17
+ from mlx_lm.tuner.datasets import load_dataset as mlx_load_dataset, CacheDataset
18
+ from mlx_lm.tuner.trainer import TrainingArgs, train, evaluate, TrainingCallback
19
+ import mlx.optimizers as optim
20
+ MLX_AVAILABLE = True
21
+ except Exception as exc: # noqa: BLE001
22
+ # Keep the module importable when MLX/metal is missing so we can show a clear message.
23
+ MLX_AVAILABLE = False
24
+ mx = nn = mlx_load = linear_to_lora_layers = mlx_load_dataset = CacheDataset = TrainingArgs = train = evaluate = TrainingCallback = optim = None # type: ignore
25
+ _MLX_IMPORT_ERROR = exc
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
@dataclass
class LoRAConfig:
    """Configuration for LoRA fine-tuning.

    Holds the standard LoRA hyper-parameters. ``target_modules`` defaults to
    the attention projection layers when the caller does not supply a list.
    """
    rank: int = 8          # low-rank dimension of the LoRA update matrices
    alpha: float = 16.0    # scaling numerator; effective scale is alpha / rank
    dropout: float = 0.0   # dropout probability applied on the LoRA path
    # None is a sentinel meaning "use the default projection layers";
    # annotated Optional because the field legitimately starts as None.
    target_modules: Optional[list] = None

    def __post_init__(self):
        # Fill in the default attention projections here rather than using a
        # mutable default argument (which would be shared across instances).
        if self.target_modules is None:
            self.target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
41
+
42
+
43
+ class MLXLoRATrainer:
44
+ """LoRA trainer using mlx_lm's implementation."""
45
+
46
+ def __init__(self, model_manager, config):
47
+ """Initialize the trainer."""
48
+ self.model_manager = model_manager
49
+ self.config = config
50
+
51
+ @staticmethod
52
+ def is_available() -> bool:
53
+ """Return True when MLX/Metal stack is importable."""
54
+ return MLX_AVAILABLE
55
+
56
+ def train(
57
+ self,
58
+ base_model_name: str,
59
+ dataset_path: Path,
60
+ output_name: str,
61
+ training_config: Any,
62
+ progress_callback: Optional[Callable] = None
63
+ ) -> bool:
64
+ """
65
+ Train a model using LoRA with mlx_lm.
66
+
67
+ Args:
68
+ base_model_name: Name of the base model to fine-tune
69
+ dataset_path: Path to the training dataset
70
+ output_name: Name for the fine-tuned model
71
+ training_config: Training configuration
72
+ progress_callback: Optional callback for progress updates
73
+
74
+ Returns:
75
+ True if training succeeded, False otherwise
76
+ """
77
+ try:
78
+ if not MLX_AVAILABLE:
79
+ logger.error("MLX is not available; fine-tuning requires MLX.")
80
+ if "_MLX_IMPORT_ERROR" in globals():
81
+ logger.debug(f"MLX import error: {_MLX_IMPORT_ERROR}") # type: ignore[name-defined]
82
+ return False
83
+
84
+ logger.info(f"Starting MLX LoRA training: {base_model_name} -> {output_name}")
85
+
86
+ # Get the model path
87
+ model_path = self._get_model_path(base_model_name)
88
+ if not model_path:
89
+ logger.error(f"Could not find model path for {base_model_name}")
90
+ return False
91
+
92
+ # Try to reuse an already loaded model to avoid a second full load in unified memory.
93
+ model = None
94
+ tokenizer = None
95
+ if self.model_manager:
96
+ cache_keys = [
97
+ str(model_path),
98
+ base_model_name,
99
+ getattr(self.model_manager, "current_model", None),
100
+ ]
101
+ for key in cache_keys:
102
+ if not key:
103
+ continue
104
+ if model is None:
105
+ model = self.model_manager.model_cache.get(key)
106
+ if tokenizer is None:
107
+ tokenizer = self.model_manager.tokenizers.get(key)
108
+ if model is not None and tokenizer is not None:
109
+ logger.info(f"Reusing loaded model from cache (key: {key})")
110
+ break
111
+
112
+ # Load model and tokenizer using mlx_lm if not already cached
113
+ if model is None or tokenizer is None:
114
+ logger.info(f"Loading model from {model_path}")
115
+ model, tokenizer = mlx_load(str(model_path))
116
+
117
+ # Apply LoRA layers
118
+ logger.info(f"Applying LoRA with rank={training_config.lora_r}")
119
+ lora_config = LoRAConfig(
120
+ rank=training_config.lora_r,
121
+ alpha=training_config.lora_alpha,
122
+ dropout=training_config.lora_dropout,
123
+ target_modules=training_config.target_modules
124
+ )
125
+
126
+ # Convert linear layers to LoRA layers
127
+ # Note: linear_to_lora_layers modifies the model in-place and returns None
128
+ linear_to_lora_layers(
129
+ model,
130
+ num_layers=training_config.num_lora_layers if hasattr(training_config, 'num_lora_layers') else 16,
131
+ config={
132
+ "rank": lora_config.rank,
133
+ "dropout": lora_config.dropout,
134
+ "scale": lora_config.alpha / lora_config.rank
135
+ }
136
+ )
137
+
138
+ # Model freezing is handled automatically by linear_to_lora_layers
139
+ # Only LoRA parameters will be trainable
140
+
141
+ # Load dataset
142
+ logger.info(f"Loading dataset from {dataset_path}")
143
+ train_data = self._load_dataset(dataset_path, tokenizer, training_config)
144
+
145
+ if not train_data:
146
+ logger.error("Failed to load dataset")
147
+ return False
148
+
149
+ # Get dataset length properly
150
+ if hasattr(train_data, '__len__'):
151
+ dataset_len = len(train_data)
152
+ elif hasattr(train_data, 'data') and hasattr(train_data.data, '__len__'):
153
+ dataset_len = len(train_data.data)
154
+ else:
155
+ dataset_len = 1 # Fallback
156
+
157
+ logger.info(f"Dataset contains {dataset_len} examples")
158
+
159
+ # Setup training arguments
160
+ adapter_file = str(Path.home() / ".cortex" / "adapters" / output_name / "adapter.safetensors")
161
+ Path(adapter_file).parent.mkdir(parents=True, exist_ok=True)
162
+
163
+ # Calculate iterations: total examples / (effective batch) * epochs
164
+ effective_batch = max(1, training_config.batch_size) * max(
165
+ 1, getattr(training_config, "gradient_accumulation_steps", 1)
166
+ )
167
+ num_iters = max(1, math.ceil((dataset_len * training_config.epochs) / effective_batch))
168
+
169
+ training_args = TrainingArgs(
170
+ batch_size=training_config.batch_size,
171
+ iters=num_iters,
172
+ steps_per_report=10,
173
+ # Avoid extra evaluation passes for small datasets by setting eval steps beyond total iters
174
+ steps_per_eval=num_iters + 1,
175
+ val_batches=1, # Just 1 validation batch
176
+ steps_per_save=100,
177
+ adapter_file=adapter_file,
178
+ grad_checkpoint=training_config.gradient_checkpointing if hasattr(training_config, 'gradient_checkpointing') else False,
179
+ )
180
+
181
+ # Setup optimizer with learning rate
182
+ optimizer = optim.AdamW(
183
+ learning_rate=training_config.learning_rate,
184
+ weight_decay=training_config.weight_decay if hasattr(training_config, 'weight_decay') else 0.01
185
+ )
186
+
187
+ # Create a simple progress tracker
188
+ class ProgressTracker(TrainingCallback):
189
+ def __init__(self, callback, dataset_len, epochs):
190
+ self.callback = callback
191
+ self.total_iters = training_args.iters
192
+ self.steps_per_epoch = max(1, dataset_len // training_args.batch_size)
193
+ self.epochs = epochs
194
+
195
+ def on_train_loss_report(self, train_info: dict):
196
+ """Called when training loss is reported."""
197
+ if self.callback:
198
+ iteration = train_info.get('iteration', 0)
199
+ loss = train_info.get('train_loss', 0.0)
200
+ # MLX iterations start at 1, not 0, so adjust
201
+ actual_iter = iteration - 1
202
+ # Calculate epoch based on actual iteration
203
+ epoch = actual_iter // self.steps_per_epoch
204
+ step = actual_iter % self.steps_per_epoch
205
+ # Ensure epoch doesn't exceed total epochs
206
+ epoch = min(epoch, self.epochs - 1)
207
+ self.callback(epoch, step, loss)
208
+
209
+ tracker = ProgressTracker(progress_callback, dataset_len, training_config.epochs) if progress_callback else None
210
+
211
+ # Prepare validation dataset
212
+ # For MLX training, we always need a validation dataset (can't be None)
213
+ # For small datasets, we'll use the same data for validation
214
+ val_data = train_data # Default to using training data for validation
215
+ logger.info("Using training data for validation (small dataset)")
216
+
217
+ # Training loop
218
+ logger.info("Starting training...")
219
+ # Note: train() doesn't return anything, it modifies model in-place and saves weights
220
+ train(
221
+ model,
222
+ optimizer,
223
+ train_dataset=train_data,
224
+ val_dataset=val_data, # Use proper validation dataset or None
225
+ args=training_args,
226
+ training_callback=tracker
227
+ )
228
+
229
+ # Save the fine-tuned model
230
+ logger.info(f"Saving fine-tuned model to {output_name}")
231
+ adapter_dir = Path(training_args.adapter_file).parent
232
+ success = self._save_model(
233
+ model=model,
234
+ tokenizer=tokenizer,
235
+ output_name=output_name,
236
+ base_model_name=base_model_name,
237
+ adapter_path=str(adapter_dir)
238
+ )
239
+
240
+ if success:
241
+ logger.info(f"Successfully saved fine-tuned model as {output_name}")
242
+ # Clean up training checkpoints after successful save
243
+ self._cleanup_checkpoints(adapter_dir)
244
+ return True
245
+ else:
246
+ logger.error("Failed to save fine-tuned model")
247
+ return False
248
+
249
+ except KeyboardInterrupt:
250
+ logger.info("Training interrupted by user")
251
+ print("\n\033[93m⚠\033[0m Training interrupted by user")
252
+ return False
253
+ except Exception as e:
254
+ logger.error(f"Training failed: {e}")
255
+ print(f"\n\033[31m✗\033[0m Training error: {str(e)}")
256
+ import traceback
257
+ traceback.print_exc()
258
+ return False
259
+
260
+ def _get_model_path(self, model_name: str) -> Optional[Path]:
261
+ """Get the path to the model, prioritizing MLX models."""
262
+ # First check if it's already an MLX model (converted or fine-tuned)
263
+ mlx_path = Path.home() / ".cortex" / "mlx_models" / model_name
264
+ if mlx_path.exists():
265
+ logger.info(f"Found MLX model at: {mlx_path}")
266
+ return mlx_path
267
+
268
+ # Check in models directory
269
+ models_path = Path.home() / ".cortex" / "models" / model_name
270
+ if models_path.exists():
271
+ logger.info(f"Found model at: {models_path}")
272
+ return models_path
273
+
274
+ # Check in configured models directory (most common location)
275
+ if self.model_manager and self.model_manager.config:
276
+ try:
277
+ config_model_path = Path(self.model_manager.config.model.model_path).expanduser().resolve()
278
+ config_path = config_model_path / model_name
279
+ if config_path.exists():
280
+ logger.info(f"Found model in configured path: {config_path}")
281
+ return config_path
282
+ except Exception as e:
283
+ logger.debug(f"Could not check configured model path: {e}")
284
+
285
+ # Check if it's a full path
286
+ if Path(model_name).exists():
287
+ full_path = Path(model_name).resolve()
288
+ logger.info(f"Found model at full path: {full_path}")
289
+ return full_path
290
+
291
+ # Last resort: check if it's a relative path in current directory
292
+ current_path = Path.cwd() / model_name
293
+ if current_path.exists():
294
+ logger.info(f"Found model at current directory: {current_path}")
295
+ return current_path
296
+
297
+ logger.error(f"Model not found: {model_name}")
298
+ return None
299
+
300
+ def _load_dataset(self, dataset_path: Path, tokenizer: Any, training_config: Any) -> Optional[Any]:
301
+ """Load and prepare the dataset."""
302
+ try:
303
+ from mlx_lm.tuner.datasets import CacheDataset, TextDataset
304
+
305
+ # Load JSONL dataset
306
+ examples = []
307
+ with open(dataset_path, 'r') as f:
308
+ for line in f:
309
+ data = json.loads(line.strip())
310
+ examples.append(data)
311
+
312
+ # Check data format and create appropriate dataset
313
+ if not examples:
314
+ logger.error("No examples found in dataset")
315
+ return None
316
+
317
+ sample = examples[0]
318
+
319
+ # Convert all formats to text format for simplicity
320
+ # This avoids issues with tokenizers that don't have chat templates
321
+ text_examples = []
322
+ max_seq_len = getattr(training_config, "max_sequence_length", None)
323
+ # crude char-level guard to avoid very long sequences; token-level truncation happens in tokenizer
324
+ max_chars = max_seq_len * 4 if max_seq_len else None
325
+ for example in examples:
326
+ if 'prompt' in example and 'response' in example:
327
+ # Format as a simple conversation
328
+ text = f"User: {example['prompt']}\n\nAssistant: {example['response']}"
329
+ elif 'prompt' in example and 'completion' in example:
330
+ text = f"User: {example['prompt']}\n\nAssistant: {example['completion']}"
331
+ elif 'text' in example:
332
+ text = example['text']
333
+ else:
334
+ logger.warning(f"Skipping example with unsupported format: {example}")
335
+ continue
336
+ if max_chars and len(text) > max_chars:
337
+ text = text[:max_chars]
338
+ text_examples.append({'text': text})
339
+
340
+ if not text_examples:
341
+ logger.error("No valid examples found after conversion")
342
+ return None
343
+
344
+ # Create TextDataset which just uses tokenizer.encode()
345
+ dataset = TextDataset(
346
+ data=text_examples,
347
+ tokenizer=tokenizer,
348
+ text_key='text'
349
+ )
350
+
351
+ # Wrap with CacheDataset for efficiency
352
+ cached_dataset = CacheDataset(dataset)
353
+
354
+ logger.info(f"Loaded {len(text_examples)} training examples")
355
+ return cached_dataset
356
+
357
+ except ImportError as e:
358
+ logger.error(f"Required dataset classes not available: {e}")
359
+ return None
360
+ except FileNotFoundError:
361
+ logger.error(f"Dataset file not found: {dataset_path}")
362
+ return None
363
+ except json.JSONDecodeError as e:
364
+ logger.error(f"Invalid JSON in dataset: {e}")
365
+ return None
366
+ except Exception as e:
367
+ logger.error(f"Error loading dataset: {e}")
368
+ import traceback
369
+ traceback.print_exc()
370
+ return None
371
+
372
+ def _save_model(
373
+ self,
374
+ model: Any,
375
+ tokenizer: Any,
376
+ output_name: str,
377
+ base_model_name: str,
378
+ adapter_path: str
379
+ ) -> bool:
380
+ """Save the fine-tuned model with integrated LoRA weights."""
381
+ try:
382
+ # Always save to MLX models directory for consistent loading
383
+ output_dir = Path.home() / ".cortex" / "mlx_models" / output_name
384
+ output_dir.mkdir(parents=True, exist_ok=True)
385
+
386
+ # Get base model path
387
+ base_model_path = self._get_model_path(base_model_name)
388
+ if not base_model_path or not base_model_path.exists():
389
+ logger.error(f"Base model path not found: {base_model_name}")
390
+ return False
391
+
392
+ logger.info(f"Saving fine-tuned model to {output_dir}")
393
+
394
+ # Copy base model files and add adapter
395
+ # Note: mlx_lm doesn't have a save function, the adapter is saved separately by train()
396
+ logger.info(f"Copying base model files from {base_model_path} to {output_dir}")
397
+ for file in base_model_path.glob("*"):
398
+ if file.is_file():
399
+ shutil.copy2(file, output_dir / file.name)
400
+ elif file.is_dir():
401
+ shutil.copytree(file, output_dir / file.name, dirs_exist_ok=True)
402
+
403
+ # Copy adapter files (only the final adapter, not checkpoints)
404
+ adapter_path = Path(adapter_path)
405
+ if adapter_path.exists():
406
+ for adapter_file in adapter_path.glob("*.safetensors"):
407
+ # Skip checkpoint files (e.g., 0000100_adapters.safetensors)
408
+ if adapter_file.name.endswith('_adapters.safetensors'):
409
+ logger.debug(f"Skipping checkpoint: {adapter_file.name}")
410
+ continue
411
+ logger.info(f"Copying adapter: {adapter_file.name}")
412
+ shutil.copy2(adapter_file, output_dir / adapter_file.name)
413
+
414
+ if (adapter_path / "adapter_config.json").exists():
415
+ shutil.copy2(adapter_path / "adapter_config.json", output_dir / "adapter_config.json")
416
+
417
+ # Update config to mark as fine-tuned
418
+ config_path = output_dir / "config.json"
419
+ if config_path.exists():
420
+ with open(config_path, 'r') as f:
421
+ config = json.load(f)
422
+
423
+ # Add fine-tuning metadata
424
+ config["fine_tuned"] = True
425
+ config["base_model"] = base_model_name
426
+ config["fine_tuning_method"] = "LoRA"
427
+ config["lora_adapter"] = True
428
+ config["created_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
429
+
430
+ with open(config_path, 'w') as f:
431
+ json.dump(config, f, indent=2)
432
+
433
+ # Create a marker file for proper detection
434
+ with open(output_dir / "fine_tuned.marker", 'w') as f:
435
+ f.write(f"LoRA fine-tuned version of {base_model_name}\n")
436
+ f.write(f"Created: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
437
+ f.write(f"Adapter path: {adapter_path}\n")
438
+ f.write(f"Output directory: {output_dir}\n")
439
+
440
+ logger.info(f"Fine-tuned model successfully saved to {output_dir}")
441
+ return True
442
+
443
+ except Exception as e:
444
+ logger.error(f"Error saving model: {e}")
445
+ import traceback
446
+ traceback.print_exc()
447
+ return False
448
+
449
+ def _cleanup_checkpoints(self, adapter_dir: Path) -> None:
450
+ """
451
+ Clean up training checkpoint files after successful training.
452
+
453
+ Checkpoints are intermediate saves during training (e.g., 0000100_adapters.safetensors).
454
+ We keep them during training for crash recovery but delete after successful completion.
455
+
456
+ Args:
457
+ adapter_dir: Directory containing adapter files and checkpoints
458
+ """
459
+ try:
460
+ if not adapter_dir.exists():
461
+ return
462
+
463
+ checkpoint_files = []
464
+ total_size = 0
465
+
466
+ # Find all checkpoint files (pattern: NNNNNNN_adapters.safetensors)
467
+ for file in adapter_dir.glob("*_adapters.safetensors"):
468
+ # Check if filename matches checkpoint pattern (digits followed by _adapters.safetensors)
469
+ filename = file.name
470
+ if filename.endswith("_adapters.safetensors"):
471
+ # Extract the prefix before _adapters
472
+ prefix = filename[:-len("_adapters.safetensors")]
473
+ # Check if prefix is all digits (checkpoint pattern)
474
+ if prefix.isdigit():
475
+ checkpoint_files.append(file)
476
+ total_size += file.stat().st_size
477
+
478
+ if checkpoint_files:
479
+ # Convert size to human-readable format
480
+ size_gb = total_size / (1024 ** 3)
481
+ size_str = f"{size_gb:.2f}GB" if size_gb >= 1 else f"{total_size / (1024 ** 2):.1f}MB"
482
+
483
+ logger.info(f"Cleaning up {len(checkpoint_files)} training checkpoints ({size_str})")
484
+
485
+ # Delete checkpoint files
486
+ for checkpoint in checkpoint_files:
487
+ try:
488
+ checkpoint.unlink()
489
+ logger.debug(f"Deleted checkpoint: {checkpoint.name}")
490
+ except Exception as e:
491
+ logger.warning(f"Failed to delete checkpoint {checkpoint.name}: {e}")
492
+
493
+ logger.info(f"✓ Freed {size_str} by removing training checkpoints")
494
+ print(f"\033[92m✓\033[0m Cleaned up {len(checkpoint_files)} training checkpoints ({size_str})")
495
+ else:
496
+ logger.debug("No checkpoint files to clean up")
497
+
498
+ except Exception as e:
499
+ # Don't fail the training if cleanup fails, just log the error
500
+ logger.warning(f"Checkpoint cleanup failed (non-critical): {e}")
501
+ # Still inform user that training succeeded but cleanup had issues
502
+ print(f"\033[93m⚠\033[0m Training succeeded but checkpoint cleanup encountered issues: {e}")