cortex_llm-1.0.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. cortex/__init__.py +73 -0
  2. cortex/__main__.py +83 -0
  3. cortex/config.py +329 -0
  4. cortex/conversation_manager.py +468 -0
  5. cortex/fine_tuning/__init__.py +8 -0
  6. cortex/fine_tuning/dataset.py +332 -0
  7. cortex/fine_tuning/mlx_lora_trainer.py +502 -0
  8. cortex/fine_tuning/trainer.py +957 -0
  9. cortex/fine_tuning/wizard.py +707 -0
  10. cortex/gpu_validator.py +467 -0
  11. cortex/inference_engine.py +727 -0
  12. cortex/metal/__init__.py +275 -0
  13. cortex/metal/gpu_validator.py +177 -0
  14. cortex/metal/memory_pool.py +886 -0
  15. cortex/metal/mlx_accelerator.py +678 -0
  16. cortex/metal/mlx_converter.py +638 -0
  17. cortex/metal/mps_optimizer.py +417 -0
  18. cortex/metal/optimizer.py +665 -0
  19. cortex/metal/performance_profiler.py +364 -0
  20. cortex/model_downloader.py +130 -0
  21. cortex/model_manager.py +2187 -0
  22. cortex/quantization/__init__.py +5 -0
  23. cortex/quantization/dynamic_quantizer.py +736 -0
  24. cortex/template_registry/__init__.py +15 -0
  25. cortex/template_registry/auto_detector.py +144 -0
  26. cortex/template_registry/config_manager.py +234 -0
  27. cortex/template_registry/interactive.py +260 -0
  28. cortex/template_registry/registry.py +347 -0
  29. cortex/template_registry/template_profiles/__init__.py +5 -0
  30. cortex/template_registry/template_profiles/base.py +142 -0
  31. cortex/template_registry/template_profiles/complex/__init__.py +5 -0
  32. cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
  33. cortex/template_registry/template_profiles/standard/__init__.py +9 -0
  34. cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
  35. cortex/template_registry/template_profiles/standard/chatml.py +82 -0
  36. cortex/template_registry/template_profiles/standard/gemma.py +103 -0
  37. cortex/template_registry/template_profiles/standard/llama.py +87 -0
  38. cortex/template_registry/template_profiles/standard/simple.py +65 -0
  39. cortex/ui/__init__.py +120 -0
  40. cortex/ui/cli.py +1685 -0
  41. cortex/ui/markdown_render.py +185 -0
  42. cortex/ui/terminal_app.py +534 -0
  43. cortex_llm-1.0.0.dist-info/METADATA +275 -0
  44. cortex_llm-1.0.0.dist-info/RECORD +48 -0
  45. cortex_llm-1.0.0.dist-info/WHEEL +5 -0
  46. cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
  47. cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
  48. cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/fine_tuning/wizard.py
@@ -0,0 +1,707 @@
1
+ """Interactive fine-tuning wizard for Cortex."""
2
+
3
+ import os
4
+ import sys
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import Optional, Dict, Any, List, Tuple
8
+ import json
9
+ import time
10
+
11
+ from cortex.model_manager import ModelManager, ModelFormat
12
+ from cortex.config import Config
13
+ from .trainer import LoRATrainer, TrainingConfig, SmartConfigFactory
14
+
15
+ from .mlx_lora_trainer import MLXLoRATrainer
16
+ from .dataset import DatasetPreparer
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class FineTuneWizard:
22
+ """Interactive wizard for fine-tuning models - Cortex style."""
23
+
24
+ def __init__(self, model_manager: ModelManager, config: Config):
25
+ """Initialize the fine-tuning wizard."""
26
+ self.model_manager = model_manager
27
+ self.config = config
28
+ self.trainer = None
29
+ self.dataset_preparer = DatasetPreparer()
30
+ self.cli = None # Will be set by CLI when running
31
+
32
+ def get_terminal_width(self) -> int:
33
+ """Get terminal width."""
34
+ if self.cli:
35
+ return self.cli.get_terminal_width()
36
+ return 80
37
+
38
+ def start(self) -> Tuple[bool, str]:
39
+ """Start the interactive fine-tuning experience."""
40
+
41
+ try:
42
+ # Hard block if MLX is not installed/available. Fine-tuning depends on it.
43
+ if not MLXLoRATrainer.is_available():
44
+ message = "Fine-tuning requires MLX/Metal, but the MLX stack is not available in this environment."
45
+ print(f"\n\033[31m✗\033[0m {message}")
46
+ return False, message
47
+ # Step 1: Select base model
48
+ base_model = self._select_base_model()
49
+ if not base_model:
50
+ return False, "Fine-tuning cancelled"
51
+
52
+ # Step 2: Select or prepare dataset
53
+ dataset_path = self._prepare_dataset()
54
+ if not dataset_path:
55
+ return False, "Fine-tuning cancelled"
56
+
57
+ # Step 3: Configure training settings
58
+ training_config = self._configure_training(base_model, dataset_path)
59
+ if not training_config:
60
+ return False, "Fine-tuning cancelled"
61
+
62
+ # Step 4: Choose output name
63
+ output_name = self._get_output_name(base_model)
64
+ if not output_name:
65
+ return False, "Fine-tuning cancelled"
66
+
67
+ # Step 5: Confirm and start training
68
+ if not self._confirm_settings(base_model, dataset_path, training_config, output_name):
69
+ return False, "Fine-tuning cancelled"
70
+
71
+ # Step 6: Run training
72
+ success = self._run_training(base_model, dataset_path, training_config, output_name)
73
+
74
+ if success:
75
+ return True, f"Fine-tuned model saved as: {output_name}"
76
+ else:
77
+ return False, "Training failed"
78
+
79
+ except KeyboardInterrupt:
80
+ print("\n\033[93m⚠\033[0m Fine-tuning cancelled by user")
81
+ return False, "Fine-tuning cancelled"
82
+ except FileNotFoundError as e:
83
+ logger.error(f"File not found: {e}")
84
+ print(f"\n\033[31m✗\033[0m File not found: {e}")
85
+ return False, f"File not found: {str(e)}"
86
+ except PermissionError as e:
87
+ logger.error(f"Permission denied: {e}")
88
+ print(f"\n\033[31m✗\033[0m Permission denied: {e}")
89
+ return False, f"Permission denied: {str(e)}"
90
+ except Exception as e:
91
+ logger.error(f"Fine-tuning failed: {e}")
92
+ print(f"\n\033[31m✗\033[0m Unexpected error: {e}")
93
+ import traceback
94
+ traceback.print_exc()
95
+ return False, f"Fine-tuning failed: {str(e)}"
96
+
97
+ def _select_base_model(self) -> Optional[str]:
98
+ """Select the base model to fine-tune."""
99
+ width = min(self.get_terminal_width() - 2, 70)
100
+
101
+ # Get available models
102
+ models = self._get_available_models()
103
+
104
+ if not models:
105
+ print("\033[31m✗\033[0m No models available. Use \033[93m/download\033[0m to get models.")
106
+ return None
107
+
108
+ # Check if a model is already loaded
109
+ if self.model_manager.current_model:
110
+ current_model_name = self.model_manager.current_model
111
+
112
+ # Create dialog box for current model
113
+ print()
114
+ self.cli.print_box_header("Fine-Tuning Setup", width)
115
+ self.cli.print_empty_line(width)
116
+
117
+ self.cli.print_box_line(f" \033[96mCurrent Model:\033[0m \033[93m{current_model_name}\033[0m", width)
118
+
119
+ self.cli.print_empty_line(width)
120
+ self.cli.print_box_separator(width)
121
+ self.cli.print_empty_line(width)
122
+
123
+ self.cli.print_box_line(" Use this model for fine-tuning?", width)
124
+ self.cli.print_empty_line(width)
125
+ self.cli.print_box_line(" \033[93m[Y]\033[0m Yes, use this model", width)
126
+ self.cli.print_box_line(" \033[93m[N]\033[0m No, select another", width)
127
+
128
+ self.cli.print_empty_line(width)
129
+ self.cli.print_box_footer(width)
130
+
131
+ choice = input("\n\033[96m▶\033[0m Choice (\033[93my\033[0m/\033[2mn\033[0m): ").strip().lower()
132
+
133
+ if choice in ['y', 'yes', '']:
134
+ print(f"\033[32m✓\033[0m Using: {current_model_name}")
135
+ return current_model_name
136
+
137
+ # Show model selection dialog
138
+ print()
139
+ self.cli.print_box_header("Select Base Model", width)
140
+ self.cli.print_empty_line(width)
141
+
142
+ # List models with numbers
143
+ for i, (name, info) in enumerate(models[:10], 1):
144
+ size_str = f"{info['size_gb']:.1f}GB"
145
+ format_str = info['format']
146
+ line = f" \033[93m[{i}]\033[0m {name} \033[2m({size_str}, {format_str})\033[0m"
147
+ self.cli.print_box_line(line, width)
148
+
149
+ if len(models) > 10:
150
+ self.cli.print_empty_line(width)
151
+ self.cli.print_box_line(f" \033[2m... and {len(models) - 10} more models available\033[0m", width)
152
+
153
+ self.cli.print_empty_line(width)
154
+ self.cli.print_box_footer(width)
155
+
156
+ # Get user selection
157
+ choice = self.cli.get_input_with_escape(f"Select model (1-{len(models)})")
158
+
159
+ if choice is None:
160
+ return None
161
+
162
+ try:
163
+ idx = int(choice) - 1
164
+ if 0 <= idx < len(models):
165
+ selected_model = models[idx][0]
166
+ print(f"\033[32m✓\033[0m Selected: {selected_model}")
167
+ return selected_model
168
+ else:
169
+ print("\033[31m✗\033[0m Invalid selection")
170
+ return None
171
+ except ValueError:
172
+ print("\033[31m✗\033[0m Please enter a valid number")
173
+ return None
174
+
175
+ def _prepare_dataset(self) -> Optional[Path]:
176
+ """Prepare the training dataset."""
177
+ width = min(self.get_terminal_width() - 2, 70)
178
+
179
+ # Show dataset options dialog
180
+ print()
181
+ self.cli.print_box_header("Training Data", width)
182
+ self.cli.print_empty_line(width)
183
+
184
+ self.cli.print_box_line(" \033[96mSelect data source:\033[0m", width)
185
+ self.cli.print_empty_line(width)
186
+
187
+ self.cli.print_box_line(" \033[93m[1]\033[0m Load from file \033[2m(JSONL/CSV/TXT)\033[0m", width)
188
+ self.cli.print_box_line(" \033[93m[2]\033[0m Create interactively", width)
189
+ self.cli.print_box_line(" \033[93m[3]\033[0m Use sample dataset \033[2m(for testing)\033[0m", width)
190
+
191
+ self.cli.print_empty_line(width)
192
+ self.cli.print_box_footer(width)
193
+
194
+ choice = self.cli.get_input_with_escape("Select option (1-3)")
195
+
196
+ if choice is None:
197
+ return None
198
+
199
+ if choice == "1":
200
+ return self._load_existing_dataset()
201
+ elif choice == "2":
202
+ return self._create_interactive_dataset()
203
+ elif choice == "3":
204
+ return self._create_sample_dataset()
205
+ else:
206
+ print("\033[31m✗\033[0m Invalid selection")
207
+ return None
208
+
209
+ def _load_existing_dataset(self) -> Optional[Path]:
210
+ """Load an existing dataset file."""
211
+ while True:
212
+ file_path = input("\n\033[96m▶\033[0m Path to dataset file: ").strip()
213
+ if not file_path:
214
+ return None
215
+
216
+ # Expand user path
217
+ file_path = Path(file_path).expanduser()
218
+
219
+ if file_path.exists():
220
+ # Validate dataset format
221
+ print(f"\033[96m⚡\033[0m Validating dataset...")
222
+ valid, message, processed_path = self.dataset_preparer.validate_dataset(file_path)
223
+ if valid:
224
+ print(f"\033[32m✓\033[0m {message}")
225
+ return processed_path
226
+ else:
227
+ print(f"\033[31m✗\033[0m {message}")
228
+ retry = input("\n\033[96m▶\033[0m Try another file? (\033[93my\033[0m/\033[2mn\033[0m): ").strip().lower()
229
+ if retry not in ['y', 'yes']:
230
+ return None
231
+ else:
232
+ print(f"\033[31m✗\033[0m File not found: {file_path}")
233
+ retry = input("\n\033[96m▶\033[0m Try another file? (\033[93my\033[0m/\033[2mn\033[0m): ").strip().lower()
234
+ if retry not in ['y', 'yes']:
235
+ return None
236
+
237
+ def _create_interactive_dataset(self) -> Optional[Path]:
238
+ """Create a dataset interactively."""
239
+ width = min(self.get_terminal_width() - 2, 70)
240
+
241
+ print()
242
+ self.cli.print_box_header("Interactive Dataset Creation", width)
243
+ self.cli.print_empty_line(width)
244
+ self.cli.print_box_line(" Enter prompt-response pairs.", width)
245
+ self.cli.print_box_line(" Type '\033[93mdone\033[0m' when finished.", width)
246
+ self.cli.print_box_line(" \033[2mMinimum 5 examples recommended.\033[0m", width)
247
+ self.cli.print_empty_line(width)
248
+ self.cli.print_box_footer(width)
249
+
250
+ examples = []
251
+ example_num = 1
252
+
253
+ while True:
254
+ print(f"\n\033[96mExample {example_num}:\033[0m")
255
+ prompt = input(" \033[96m▶\033[0m Prompt: ").strip()
256
+
257
+ if prompt.lower() == "done":
258
+ if len(examples) < 5:
259
+ print(f"\033[93m⚠\033[0m You have {len(examples)} examples. Minimum recommended: 5")
260
+ cont = input("\033[96m▶\033[0m Continue anyway? (\033[2my\033[0m/\033[93mN\033[0m): ").strip().lower()
261
+ if cont != 'y':
262
+ continue
263
+ break
264
+
265
+ if not prompt:
266
+ break
267
+
268
+ response = input(" \033[96m▶\033[0m Response: ").strip()
269
+ if not response:
270
+ print("\033[31m✗\033[0m Response required")
271
+ continue
272
+
273
+ examples.append({
274
+ "prompt": prompt,
275
+ "response": response
276
+ })
277
+
278
+ example_num += 1
279
+ print("\033[32m✓\033[0m Added")
280
+
281
+ if not examples:
282
+ print("\033[31m✗\033[0m No examples provided")
283
+ return None
284
+
285
+ # Save to temporary file
286
+ dataset_path = Path.home() / ".cortex" / "temp_datasets" / "interactive_dataset.jsonl"
287
+ dataset_path.parent.mkdir(parents=True, exist_ok=True)
288
+
289
+ with open(dataset_path, 'w') as f:
290
+ for example in examples:
291
+ f.write(json.dumps(example) + '\n')
292
+
293
+ print(f"\033[32m✓\033[0m Created dataset with {len(examples)} examples")
294
+ return dataset_path
295
+
296
+ def _create_sample_dataset(self) -> Optional[Path]:
297
+ """Create a sample dataset for testing."""
298
+ print("\n\033[96m⚡\033[0m Creating sample dataset...")
299
+
300
+ dataset_path = self.dataset_preparer.create_sample_dataset("general")
301
+ print(f"\033[32m✓\033[0m Sample dataset created (5 examples)")
302
+
303
+ return dataset_path
304
+
305
+ def _configure_training(self, base_model: str, dataset_path: Path) -> Optional[TrainingConfig]:
306
+ """Configure training settings using intelligent presets."""
307
+ width = min(self.get_terminal_width() - 2, 70)
308
+
309
+ # Get model information for smart configuration
310
+ model_info = self._get_model_info(base_model)
311
+ model_size_gb = model_info.get('size_gb', 1.0) if model_info else 1.0
312
+ model_path = str(model_info.get('path', '')) if model_info else None
313
+
314
+ # Analyze model and dataset for smart defaults with accurate parameter detection
315
+ model_category, estimated_params = SmartConfigFactory.categorize_model_size(
316
+ model_size_gb, self.model_manager, model_path
317
+ )
318
+ dataset_info = SmartConfigFactory.analyze_dataset(dataset_path)
319
+
320
+ # Show intelligent configuration dialog
321
+ print()
322
+ self.cli.print_box_header("Smart Training Configuration", width)
323
+ self.cli.print_empty_line(width)
324
+
325
+ # Show detected characteristics
326
+ self.cli.print_box_line(f" \033[96mDetected:\033[0m", width)
327
+ self.cli.print_box_line(f" Model: \033[93m{model_category.title()}\033[0m ({estimated_params:.1f}B params, {model_size_gb:.1f}GB)", width)
328
+ self.cli.print_box_line(f" Dataset: \033[93m{dataset_info['size_category'].title()}\033[0m ({dataset_info['size']} examples)", width)
329
+ self.cli.print_box_line(f" Task type: \033[93m{dataset_info['task_type'].title()}\033[0m", width)
330
+
331
+ self.cli.print_empty_line(width)
332
+ self.cli.print_box_separator(width)
333
+ self.cli.print_empty_line(width)
334
+
335
+ self.cli.print_box_line(" \033[96mSelect training preset:\033[0m", width)
336
+ self.cli.print_empty_line(width)
337
+
338
+ # Get preset descriptions
339
+ presets = SmartConfigFactory.get_preset_configs()
340
+
341
+ self.cli.print_box_line(" \033[93m[1]\033[0m Quick \033[2m(fast experimentation)\033[0m", width)
342
+ self.cli.print_box_line(" \033[93m[2]\033[0m Balanced \033[2m(recommended for most cases)\033[0m", width)
343
+ self.cli.print_box_line(" \033[93m[3]\033[0m Quality \033[2m(best results, longer training)\033[0m", width)
344
+ self.cli.print_box_line(" \033[93m[4]\033[0m Expert \033[2m(full customization)\033[0m", width)
345
+
346
+ self.cli.print_empty_line(width)
347
+ self.cli.print_box_footer(width)
348
+
349
+ choice = self.cli.get_input_with_escape("Select preset (1-4)")
350
+
351
+ if choice is None:
352
+ return None
353
+
354
+ preset_map = {
355
+ "1": "quick",
356
+ "2": "balanced",
357
+ "3": "quality"
358
+ }
359
+
360
+ if choice in preset_map:
361
+ # Use smart configuration
362
+ preset = preset_map[choice]
363
+ config = SmartConfigFactory.create_smart_config(
364
+ model_size_gb=model_size_gb,
365
+ dataset_path=dataset_path,
366
+ preset=preset,
367
+ model_manager=self.model_manager,
368
+ model_path=model_path
369
+ )
370
+
371
+ # Show what the smart config decided
372
+ print(f"\n\033[96m⚡\033[0m Smart configuration applied:")
373
+ guidance = SmartConfigFactory.generate_guidance_message(config, base_model)
374
+ print(f" {guidance}")
375
+
376
+ elif choice == "4":
377
+ # Expert mode - full customization
378
+ config = self._expert_configuration(model_size_gb, dataset_path, model_category, model_path)
379
+ if not config:
380
+ return None
381
+ else:
382
+ print("\033[31m✗\033[0m Invalid selection")
383
+ return None
384
+
385
+ # Auto-adjust quantization based on model size
386
+ if model_size_gb > 30 and not config.quantization_bits:
387
+ config.quantization_bits = 4
388
+ print("\033[93m※\033[0m Auto-enabled 4-bit quantization for large model")
389
+ elif model_size_gb > 13 and not config.quantization_bits:
390
+ config.quantization_bits = 8
391
+ print("\033[93m※\033[0m Auto-enabled 8-bit quantization for medium model")
392
+
393
+ return config
394
+
395
+ def _expert_configuration(self, model_size_gb: float, dataset_path: Path, model_category: str, model_path: Optional[str] = None) -> Optional[TrainingConfig]:
396
+ """Expert mode configuration with full customization."""
397
+ width = min(self.get_terminal_width() - 2, 70)
398
+
399
+ print()
400
+ self.cli.print_box_header("Expert Configuration", width)
401
+ self.cli.print_empty_line(width)
402
+ self.cli.print_box_line(" \033[96mConfigure advanced settings:\033[0m", width)
403
+ self.cli.print_box_line(" \033[2mPress Enter to use smart defaults\033[0m", width)
404
+ self.cli.print_empty_line(width)
405
+ self.cli.print_box_footer(width)
406
+
407
+ # Get smart defaults as starting point
408
+ smart_config = SmartConfigFactory.create_smart_config(
409
+ model_size_gb=model_size_gb,
410
+ dataset_path=dataset_path,
411
+ preset="balanced",
412
+ model_manager=self.model_manager,
413
+ model_path=model_path
414
+ )
415
+
416
+ try:
417
+ # Core training parameters
418
+ print("\n\033[96m━━━ Core Training Parameters ━━━\033[0m")
419
+ epochs_str = input(f"\033[96m▶\033[0m Epochs \033[2m[{smart_config.epochs}]\033[0m: ").strip()
420
+ epochs = int(epochs_str) if epochs_str else smart_config.epochs
421
+
422
+ lr_str = input(f"\033[96m▶\033[0m Learning rate \033[2m[{smart_config.learning_rate:.1e}]\033[0m: ").strip()
423
+ learning_rate = float(lr_str) if lr_str else smart_config.learning_rate
424
+
425
+ batch_str = input(f"\033[96m▶\033[0m Batch size \033[2m[{smart_config.batch_size}]\033[0m: ").strip()
426
+ batch_size = int(batch_str) if batch_str else smart_config.batch_size
427
+
428
+ grad_acc_str = input(f"\033[96m▶\033[0m Gradient accumulation steps \033[2m[{smart_config.gradient_accumulation_steps}]\033[0m: ").strip()
429
+ grad_acc_steps = int(grad_acc_str) if grad_acc_str else smart_config.gradient_accumulation_steps
430
+
431
+ # LoRA parameters
432
+ print("\n\033[96m━━━ LoRA Parameters ━━━\033[0m")
433
+ lora_r_str = input(f"\033[96m▶\033[0m LoRA rank \033[2m[{smart_config.lora_r}]\033[0m: ").strip()
434
+ lora_r = int(lora_r_str) if lora_r_str else smart_config.lora_r
435
+
436
+ lora_alpha_str = input(f"\033[96m▶\033[0m LoRA alpha \033[2m[{smart_config.lora_alpha}]\033[0m: ").strip()
437
+ lora_alpha = int(lora_alpha_str) if lora_alpha_str else smart_config.lora_alpha
438
+
439
+ lora_dropout_str = input(f"\033[96m▶\033[0m LoRA dropout \033[2m[{smart_config.lora_dropout}]\033[0m: ").strip()
440
+ lora_dropout = float(lora_dropout_str) if lora_dropout_str else smart_config.lora_dropout
441
+
442
+ # Advanced options (optional)
443
+ print("\n\033[96m━━━ Advanced Options (Optional) ━━━\033[0m")
444
+ weight_decay_str = input(f"\033[96m▶\033[0m Weight decay \033[2m[{smart_config.weight_decay}]\033[0m: ").strip()
445
+ weight_decay = float(weight_decay_str) if weight_decay_str else smart_config.weight_decay
446
+
447
+ warmup_ratio_str = input(f"\033[96m▶\033[0m Warmup ratio \033[2m[{smart_config.warmup_ratio}]\033[0m: ").strip()
448
+ warmup_ratio = float(warmup_ratio_str) if warmup_ratio_str else smart_config.warmup_ratio
449
+
450
+ max_seq_len_str = input(f"\033[96m▶\033[0m Max sequence length \033[2m[{smart_config.max_sequence_length}]\033[0m: ").strip()
451
+ max_seq_len = int(max_seq_len_str) if max_seq_len_str else smart_config.max_sequence_length
452
+
453
+ # Create custom configuration
454
+ config = TrainingConfig(
455
+ epochs=epochs,
456
+ learning_rate=learning_rate,
457
+ batch_size=batch_size,
458
+ gradient_accumulation_steps=grad_acc_steps,
459
+ lora_r=lora_r,
460
+ lora_alpha=lora_alpha,
461
+ lora_dropout=lora_dropout,
462
+ weight_decay=weight_decay,
463
+ warmup_ratio=warmup_ratio,
464
+ max_sequence_length=max_seq_len,
465
+ task_type=smart_config.task_type,
466
+ model_size_category=smart_config.model_size_category,
467
+ estimated_parameters_b=smart_config.estimated_parameters_b,
468
+ auto_configured=False,
469
+ configuration_source="expert"
470
+ )
471
+
472
+ # Validate configuration
473
+ valid, message = config.validate()
474
+ if not valid:
475
+ print(f"\033[31m✗\033[0m Configuration error: {message}")
476
+ return None
477
+
478
+ print(f"\033[32m✓\033[0m Expert configuration created")
479
+ return config
480
+
481
+ except ValueError as e:
482
+ print(f"\033[31m✗\033[0m Invalid value entered: {e}")
483
+ return None
484
+ except KeyboardInterrupt:
485
+ print("\n\033[93m⚠\033[0m Configuration cancelled")
486
+ return None
487
+
488
+ def _get_output_name(self, base_model: str) -> Optional[str]:
489
+ """Get the output model name from user."""
490
+ width = min(self.get_terminal_width() - 2, 70)
491
+
492
+ default_name = f"{base_model}-finetuned"
493
+
494
+ # Show output name dialog
495
+ print()
496
+ self.cli.print_box_header("Output Model", width)
497
+ self.cli.print_empty_line(width)
498
+ self.cli.print_box_line(f" Enter name for fine-tuned model:", width)
499
+ self.cli.print_box_line(f" \033[2mDefault: {default_name}\033[0m", width)
500
+ self.cli.print_empty_line(width)
501
+ self.cli.print_box_footer(width)
502
+
503
+ name = input(f"\n\033[96m▶\033[0m Model name \033[2m[{default_name}]\033[0m: ").strip()
504
+ name = name if name else default_name
505
+
506
+ # Check if name already exists
507
+ existing_models = self._get_available_models()
508
+ if any(model_name == name for model_name, _ in existing_models):
509
+ choice = input(f"\n\033[93m⚠\033[0m Model '{name}' exists. Overwrite? (\033[2my\033[0m/\033[93mN\033[0m): ").strip().lower()
510
+ if choice != 'y':
511
+ return None
512
+
513
+ return name
514
+
515
+ def _confirm_settings(self, base_model: str, dataset_path: Path,
516
+ config: TrainingConfig, output_name: str) -> bool:
517
+ """Show summary and confirm settings."""
518
+ width = min(self.get_terminal_width() - 2, 70)
519
+
520
+ # Count dataset examples
521
+ example_count = sum(1 for _ in open(dataset_path))
522
+
523
+ # Estimate training time
524
+ estimated_time = self._estimate_training_time(example_count, config)
525
+
526
+ # Show summary dialog
527
+ print()
528
+ self.cli.print_box_header("Training Summary", width)
529
+ self.cli.print_empty_line(width)
530
+
531
+ self.cli.print_box_line(" \033[96mConfiguration:\033[0m", width)
532
+ self.cli.print_empty_line(width)
533
+
534
+ self.cli.print_box_line(f" Base model: \033[93m{base_model}\033[0m", width)
535
+ self.cli.print_box_line(f" Output model: \033[93m{output_name}\033[0m", width)
536
+ self.cli.print_box_line(f" Dataset: {dataset_path.name} \033[2m({example_count} examples)\033[0m", width)
537
+
538
+ self.cli.print_empty_line(width)
539
+
540
+ self.cli.print_box_line(f" Model size: \033[93m{config.model_size_category.title()}\033[0m ({config.estimated_parameters_b:.1f}B params)", width)
541
+ self.cli.print_box_line(f" Task type: {config.task_type.title()}", width)
542
+ self.cli.print_box_line(f" Config source: {config.configuration_source.replace('_', ' ').title()}", width)
543
+
544
+ self.cli.print_empty_line(width)
545
+
546
+ self.cli.print_box_line(f" Epochs: {config.epochs}", width)
547
+ self.cli.print_box_line(f" Learning rate: {config.learning_rate:.1e}", width)
548
+ self.cli.print_box_line(f" LoRA rank: {config.lora_r}", width)
549
+ self.cli.print_box_line(f" Batch size: {config.batch_size} (x{config.gradient_accumulation_steps} acc.)", width)
550
+ if config.quantization_bits:
551
+ self.cli.print_box_line(f" Quantization: {config.quantization_bits}-bit", width)
552
+
553
+ self.cli.print_empty_line(width)
554
+ self.cli.print_box_line(f" \033[2mEstimated time: {estimated_time}\033[0m", width)
555
+
556
+ self.cli.print_empty_line(width)
557
+ self.cli.print_box_separator(width)
558
+ self.cli.print_empty_line(width)
559
+
560
+ self.cli.print_box_line(" Start fine-tuning?", width)
561
+ self.cli.print_empty_line(width)
562
+ self.cli.print_box_line(" \033[93m[Y]\033[0m Yes, start training", width)
563
+ self.cli.print_box_line(" \033[93m[N]\033[0m No, cancel", width)
564
+
565
+ self.cli.print_empty_line(width)
566
+ self.cli.print_box_footer(width)
567
+
568
+ choice = input("\n\033[96m▶\033[0m Choice (\033[93my\033[0m/\033[2mn\033[0m): ").strip().lower()
569
+ return choice in ['y', 'yes', '']
570
+
571
+ def _run_training(self, base_model: str, dataset_path: Path,
572
+ config: TrainingConfig, output_name: str) -> bool:
573
+ """Run the actual training."""
574
+ print("\n\033[96m⚡\033[0m Starting fine-tuning...")
575
+
576
+ try:
577
+ # Hard requirement: MLX must be available for fine-tuning.
578
+ if not MLXLoRATrainer.is_available():
579
+ print("\n\033[31m✗\033[0m Fine-tuning requires MLX/Metal, but MLX is not available in this environment.")
580
+ return False
581
+ # Use MLXLoRATrainer for proper LoRA implementation
582
+ self.trainer = MLXLoRATrainer(self.model_manager, self.config)
583
+
584
+ # Progress tracking
585
+ start_time = time.time()
586
+ last_update = start_time
587
+
588
+ def update_progress(epoch, step, loss):
589
+ nonlocal last_update
590
+ current_time = time.time()
591
+
592
+ # Update every 0.5 seconds
593
+ if current_time - last_update > 0.5:
594
+ elapsed = current_time - start_time
595
+ progress = ((epoch * 100) + min(step, 99)) / (config.epochs * 100)
596
+
597
+ # Create progress bar
598
+ bar_width = 30
599
+ filled = int(bar_width * progress)
600
+ bar = "█" * filled + "░" * (bar_width - filled)
601
+
602
+ # Print progress
603
+ sys.stdout.write(f"\r {bar} {progress*100:.0f}% | Epoch {epoch+1}/{config.epochs} | Loss: {loss:.4f}")
604
+ sys.stdout.flush()
605
+ last_update = current_time
606
+
607
+ # Run training
608
+ success = self.trainer.train(
609
+ base_model_name=base_model,
610
+ dataset_path=dataset_path,
611
+ output_name=output_name,
612
+ training_config=config,
613
+ progress_callback=update_progress
614
+ )
615
+
616
+ print() # New line after progress
617
+
618
+ if success:
619
+ print(f"\n\033[32m✓\033[0m Fine-tuning completed!")
620
+
621
+ # Show where the model was saved
622
+ mlx_path = Path.home() / ".cortex" / "mlx_models" / output_name
623
+ if mlx_path.exists():
624
+ print(f"\n\033[96m📍\033[0m Model saved to: \033[93m{mlx_path}\033[0m")
625
+ print(f"\n\033[96m💡\033[0m To load your fine-tuned model:")
626
+ print(f" \033[93m/model {mlx_path}\033[0m")
627
+
628
+ # Check if adapter weights exist
629
+ adapter_file = mlx_path / "adapter.safetensors"
630
+ if adapter_file.exists():
631
+ size_mb = adapter_file.stat().st_size / (1024 * 1024)
632
+ print(f"\n\033[2m LoRA adapter size: {size_mb:.1f} MB\033[0m")
633
+ print(f"\033[2m Base model: {base_model}\033[0m")
634
+
635
+ return True
636
+ else:
637
+ print("\n\033[31m✗\033[0m Fine-tuning failed")
638
+ return False
639
+
640
+ except KeyboardInterrupt:
641
+ print("\n\n\033[93m⚠\033[0m Training interrupted by user")
642
+ return False
643
+ except Exception as e:
644
+ logger.error(f"Training failed: {e}")
645
+ print(f"\n\n\033[31m✗\033[0m Training error: {e}")
646
+ import traceback
647
+ traceback.print_exc()
648
+ return False
649
+
650
+ def _get_available_models(self) -> List[Tuple[str, Dict[str, Any]]]:
651
+ """Get list of available models."""
652
+ models = []
653
+
654
+ # Get models from model manager
655
+ discovered = self.model_manager.discover_available_models()
656
+
657
+ for model_info in discovered:
658
+ try:
659
+ name = model_info.get('name', 'Unknown')
660
+ info = {
661
+ 'path': Path(model_info.get('path', '')),
662
+ 'format': model_info.get('format', 'Unknown'),
663
+ 'size_gb': model_info.get('size_gb', 0.0)
664
+ }
665
+ models.append((name, info))
666
+ except Exception as e:
667
+ logger.debug(f"Error processing model info: {e}")
668
+ continue
669
+
670
+ return sorted(models, key=lambda x: x[0])
671
+
672
+ def _get_model_info(self, model_name: str) -> Optional[Dict[str, Any]]:
673
+ """Get information about a model."""
674
+ models = self._get_available_models()
675
+ for name, info in models:
676
+ if name == model_name:
677
+ return info
678
+ return None
679
+
680
+ def _estimate_training_time(self, example_count: int, config: TrainingConfig) -> str:
681
+ """Estimate training time based on dataset size, epochs, and model characteristics."""
682
+ # Base time estimation adjusted for model size and batch settings
683
+ base_seconds_per_example = {
684
+ "tiny": 0.1, # Very fast for small models
685
+ "small": 0.3, # Fast
686
+ "medium": 0.7, # Standard
687
+ "large": 1.5, # Slower for large models
688
+ "xlarge": 3.0 # Much slower
689
+ }.get(config.model_size_category, 0.7)
690
+
691
+ # Adjust for gradient accumulation (more accumulation = fewer actual updates)
692
+ effective_batch_size = config.batch_size * config.gradient_accumulation_steps
693
+ batch_factor = max(0.5, 1.0 / (effective_batch_size ** 0.5)) # Larger batches are more efficient
694
+
695
+ # Adjust for quantization (if enabled, training is faster)
696
+ quant_factor = 0.7 if config.quantization_bits else 1.0
697
+
698
+ # Calculate total time
699
+ adjusted_time_per_example = base_seconds_per_example * batch_factor * quant_factor
700
+ total_seconds = example_count * config.epochs * adjusted_time_per_example
701
+
702
+ if total_seconds < 60:
703
+ return f"~{int(total_seconds)} seconds"
704
+ elif total_seconds < 3600:
705
+ return f"~{int(total_seconds / 60)} minutes"
706
+ else:
707
+ return f"~{total_seconds / 3600:.1f} hours"
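For reference, the _estimate_training_time heuristic above multiplies a per-example base cost (picked by model size category) by an effective-batch-size factor floored at 0.5 and a 0.7 discount when quantization is enabled. A small worked check of that arithmetic, using values taken directly from the table in the code (an illustration only, not package code):

    # 500 examples, 3 epochs, "medium" model (0.7 s/example),
    # batch_size=4 with 4 gradient-accumulation steps, 8-bit quantization.
    examples, epochs = 500, 3
    base = 0.7                                     # seconds/example for "medium"
    batch_factor = max(0.5, 1.0 / (4 * 4) ** 0.5)  # 1/sqrt(16) = 0.25, floored to 0.5
    quant_factor = 0.7                             # quantization enabled
    total_seconds = examples * epochs * base * batch_factor * quant_factor
    print(f"~{int(total_seconds / 60)} minutes")   # ~6 minutes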
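The wizard is only ever driven from the Cortex CLI, so a minimal sketch of the calling pattern it expects may help when reading this diff. This is an illustration under stated assumptions, not package code: it relies only on what the file above exposes (the two-argument constructor, the externally assigned cli attribute, and the (bool, str) return from start()); the model_manager, config, and cli parameters stand in for the ModelManager, Config, and CLI helper objects that cortex/ui/cli.py already holds, whose construction is not shown in this diff.

    import json
    from cortex.fine_tuning.wizard import FineTuneWizard

    def run_fine_tuning(model_manager, config, cli):
        # cli must provide the print_box_* / get_input_with_escape helpers the
        # wizard calls; the CLI normally assigns it after construction.
        wizard = FineTuneWizard(model_manager, config)
        wizard.cli = cli
        ok, message = wizard.start()   # (success flag, human-readable status)
        return ok, message

    # Training data is plain JSONL, one {"prompt": ..., "response": ...} object
    # per line -- the same shape _create_interactive_dataset writes.
    print(json.dumps({"prompt": "What is LoRA?",
                      "response": "A low-rank adapter fine-tuning method."}))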