cortex-llm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortex/__init__.py +73 -0
- cortex/__main__.py +83 -0
- cortex/config.py +329 -0
- cortex/conversation_manager.py +468 -0
- cortex/fine_tuning/__init__.py +8 -0
- cortex/fine_tuning/dataset.py +332 -0
- cortex/fine_tuning/mlx_lora_trainer.py +502 -0
- cortex/fine_tuning/trainer.py +957 -0
- cortex/fine_tuning/wizard.py +707 -0
- cortex/gpu_validator.py +467 -0
- cortex/inference_engine.py +727 -0
- cortex/metal/__init__.py +275 -0
- cortex/metal/gpu_validator.py +177 -0
- cortex/metal/memory_pool.py +886 -0
- cortex/metal/mlx_accelerator.py +678 -0
- cortex/metal/mlx_converter.py +638 -0
- cortex/metal/mps_optimizer.py +417 -0
- cortex/metal/optimizer.py +665 -0
- cortex/metal/performance_profiler.py +364 -0
- cortex/model_downloader.py +130 -0
- cortex/model_manager.py +2187 -0
- cortex/quantization/__init__.py +5 -0
- cortex/quantization/dynamic_quantizer.py +736 -0
- cortex/template_registry/__init__.py +15 -0
- cortex/template_registry/auto_detector.py +144 -0
- cortex/template_registry/config_manager.py +234 -0
- cortex/template_registry/interactive.py +260 -0
- cortex/template_registry/registry.py +347 -0
- cortex/template_registry/template_profiles/__init__.py +5 -0
- cortex/template_registry/template_profiles/base.py +142 -0
- cortex/template_registry/template_profiles/complex/__init__.py +5 -0
- cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
- cortex/template_registry/template_profiles/standard/__init__.py +9 -0
- cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
- cortex/template_registry/template_profiles/standard/chatml.py +82 -0
- cortex/template_registry/template_profiles/standard/gemma.py +103 -0
- cortex/template_registry/template_profiles/standard/llama.py +87 -0
- cortex/template_registry/template_profiles/standard/simple.py +65 -0
- cortex/ui/__init__.py +120 -0
- cortex/ui/cli.py +1685 -0
- cortex/ui/markdown_render.py +185 -0
- cortex/ui/terminal_app.py +534 -0
- cortex_llm-1.0.0.dist-info/METADATA +275 -0
- cortex_llm-1.0.0.dist-info/RECORD +48 -0
- cortex_llm-1.0.0.dist-info/WHEEL +5 -0
- cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
- cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
- cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,707 @@
|
|
|
1
|
+
"""Interactive fine-tuning wizard for Cortex."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
import logging
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Optional, Dict, Any, List, Tuple
|
|
8
|
+
import json
|
|
9
|
+
import time
|
|
10
|
+
|
|
11
|
+
from cortex.model_manager import ModelManager, ModelFormat
|
|
12
|
+
from cortex.config import Config
|
|
13
|
+
from .trainer import LoRATrainer, TrainingConfig, SmartConfigFactory
|
|
14
|
+
|
|
15
|
+
from .mlx_lora_trainer import MLXLoRATrainer
|
|
16
|
+
from .dataset import DatasetPreparer
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FineTuneWizard:
|
|
22
|
+
"""Interactive wizard for fine-tuning models - Cortex style."""
|
|
23
|
+
|
|
24
|
+
def __init__(self, model_manager: ModelManager, config: Config):
|
|
25
|
+
"""Initialize the fine-tuning wizard."""
|
|
26
|
+
self.model_manager = model_manager
|
|
27
|
+
self.config = config
|
|
28
|
+
self.trainer = None
|
|
29
|
+
self.dataset_preparer = DatasetPreparer()
|
|
30
|
+
self.cli = None # Will be set by CLI when running
|
|
31
|
+
|
|
32
|
+
def get_terminal_width(self) -> int:
|
|
33
|
+
"""Get terminal width."""
|
|
34
|
+
if self.cli:
|
|
35
|
+
return self.cli.get_terminal_width()
|
|
36
|
+
return 80
|
|
37
|
+
|
|
38
|
+
    def start(self) -> Tuple[bool, str]:
        """Start the interactive fine-tuning experience.

        Walks the user through six steps: MLX availability check, base-model
        selection, dataset preparation, training configuration, output
        naming, confirmation, and finally training.

        Returns:
            Tuple of (success flag, human-readable status message).
        """

        try:
            # Hard block if MLX is not installed/available. Fine-tuning depends on it.
            if not MLXLoRATrainer.is_available():
                message = "Fine-tuning requires MLX/Metal, but the MLX stack is not available in this environment."
                print(f"\n\033[31m✗\033[0m {message}")
                return False, message
            # Step 1: Select base model
            base_model = self._select_base_model()
            if not base_model:
                return False, "Fine-tuning cancelled"

            # Step 2: Select or prepare dataset
            dataset_path = self._prepare_dataset()
            if not dataset_path:
                return False, "Fine-tuning cancelled"

            # Step 3: Configure training settings
            training_config = self._configure_training(base_model, dataset_path)
            if not training_config:
                return False, "Fine-tuning cancelled"

            # Step 4: Choose output name
            output_name = self._get_output_name(base_model)
            if not output_name:
                return False, "Fine-tuning cancelled"

            # Step 5: Confirm and start training
            if not self._confirm_settings(base_model, dataset_path, training_config, output_name):
                return False, "Fine-tuning cancelled"

            # Step 6: Run training
            success = self._run_training(base_model, dataset_path, training_config, output_name)

            if success:
                return True, f"Fine-tuned model saved as: {output_name}"
            else:
                return False, "Training failed"

        except KeyboardInterrupt:
            # Ctrl-C anywhere in the flow aborts cleanly with a cancel message.
            print("\n\033[93m⚠\033[0m Fine-tuning cancelled by user")
            return False, "Fine-tuning cancelled"
        except FileNotFoundError as e:
            logger.error(f"File not found: {e}")
            print(f"\n\033[31m✗\033[0m File not found: {e}")
            return False, f"File not found: {str(e)}"
        except PermissionError as e:
            logger.error(f"Permission denied: {e}")
            print(f"\n\033[31m✗\033[0m Permission denied: {e}")
            return False, f"Permission denied: {str(e)}"
        except Exception as e:
            # Catch-all boundary: log, show a traceback for debugging, report failure.
            logger.error(f"Fine-tuning failed: {e}")
            print(f"\n\033[31m✗\033[0m Unexpected error: {e}")
            import traceback
            traceback.print_exc()
            return False, f"Fine-tuning failed: {str(e)}"
    def _select_base_model(self) -> Optional[str]:
        """Select the base model to fine-tune.

        Offers the currently-loaded model as a shortcut, otherwise shows a
        numbered list of discovered models (first 10 displayed; any index up
        to the full model count is accepted).

        Returns:
            The chosen model name, or None if cancelled/invalid.
        """
        width = min(self.get_terminal_width() - 2, 70)

        # Get available models
        models = self._get_available_models()

        if not models:
            print("\033[31m✗\033[0m No models available. Use \033[93m/download\033[0m to get models.")
            return None

        # Check if a model is already loaded
        # NOTE(review): assumes self.cli was attached by the CLI before start()
        # runs — box-drawing below would raise AttributeError otherwise; confirm.
        if self.model_manager.current_model:
            current_model_name = self.model_manager.current_model

            # Create dialog box for current model
            print()
            self.cli.print_box_header("Fine-Tuning Setup", width)
            self.cli.print_empty_line(width)

            self.cli.print_box_line(f" \033[96mCurrent Model:\033[0m \033[93m{current_model_name}\033[0m", width)

            self.cli.print_empty_line(width)
            self.cli.print_box_separator(width)
            self.cli.print_empty_line(width)

            self.cli.print_box_line(" Use this model for fine-tuning?", width)
            self.cli.print_empty_line(width)
            self.cli.print_box_line(" \033[93m[Y]\033[0m Yes, use this model", width)
            self.cli.print_box_line(" \033[93m[N]\033[0m No, select another", width)

            self.cli.print_empty_line(width)
            self.cli.print_box_footer(width)

            # Empty input defaults to "yes" (use the loaded model).
            choice = input("\n\033[96m▶\033[0m Choice (\033[93my\033[0m/\033[2mn\033[0m): ").strip().lower()

            if choice in ['y', 'yes', '']:
                print(f"\033[32m✓\033[0m Using: {current_model_name}")
                return current_model_name

        # Show model selection dialog
        print()
        self.cli.print_box_header("Select Base Model", width)
        self.cli.print_empty_line(width)

        # List models with numbers (only the first 10 are rendered).
        for i, (name, info) in enumerate(models[:10], 1):
            size_str = f"{info['size_gb']:.1f}GB"
            format_str = info['format']
            line = f" \033[93m[{i}]\033[0m {name} \033[2m({size_str}, {format_str})\033[0m"
            self.cli.print_box_line(line, width)

        if len(models) > 10:
            self.cli.print_empty_line(width)
            self.cli.print_box_line(f" \033[2m... and {len(models) - 10} more models available\033[0m", width)

        self.cli.print_empty_line(width)
        self.cli.print_box_footer(width)

        # Get user selection; escape/cancel yields None.
        choice = self.cli.get_input_with_escape(f"Select model (1-{len(models)})")

        if choice is None:
            return None

        try:
            idx = int(choice) - 1
            # Indices beyond the 10 displayed entries are still valid picks.
            if 0 <= idx < len(models):
                selected_model = models[idx][0]
                print(f"\033[32m✓\033[0m Selected: {selected_model}")
                return selected_model
            else:
                print("\033[31m✗\033[0m Invalid selection")
                return None
        except ValueError:
            print("\033[31m✗\033[0m Please enter a valid number")
            return None
def _prepare_dataset(self) -> Optional[Path]:
|
|
176
|
+
"""Prepare the training dataset."""
|
|
177
|
+
width = min(self.get_terminal_width() - 2, 70)
|
|
178
|
+
|
|
179
|
+
# Show dataset options dialog
|
|
180
|
+
print()
|
|
181
|
+
self.cli.print_box_header("Training Data", width)
|
|
182
|
+
self.cli.print_empty_line(width)
|
|
183
|
+
|
|
184
|
+
self.cli.print_box_line(" \033[96mSelect data source:\033[0m", width)
|
|
185
|
+
self.cli.print_empty_line(width)
|
|
186
|
+
|
|
187
|
+
self.cli.print_box_line(" \033[93m[1]\033[0m Load from file \033[2m(JSONL/CSV/TXT)\033[0m", width)
|
|
188
|
+
self.cli.print_box_line(" \033[93m[2]\033[0m Create interactively", width)
|
|
189
|
+
self.cli.print_box_line(" \033[93m[3]\033[0m Use sample dataset \033[2m(for testing)\033[0m", width)
|
|
190
|
+
|
|
191
|
+
self.cli.print_empty_line(width)
|
|
192
|
+
self.cli.print_box_footer(width)
|
|
193
|
+
|
|
194
|
+
choice = self.cli.get_input_with_escape("Select option (1-3)")
|
|
195
|
+
|
|
196
|
+
if choice is None:
|
|
197
|
+
return None
|
|
198
|
+
|
|
199
|
+
if choice == "1":
|
|
200
|
+
return self._load_existing_dataset()
|
|
201
|
+
elif choice == "2":
|
|
202
|
+
return self._create_interactive_dataset()
|
|
203
|
+
elif choice == "3":
|
|
204
|
+
return self._create_sample_dataset()
|
|
205
|
+
else:
|
|
206
|
+
print("\033[31m✗\033[0m Invalid selection")
|
|
207
|
+
return None
|
|
208
|
+
|
|
209
|
+
def _load_existing_dataset(self) -> Optional[Path]:
|
|
210
|
+
"""Load an existing dataset file."""
|
|
211
|
+
while True:
|
|
212
|
+
file_path = input("\n\033[96m▶\033[0m Path to dataset file: ").strip()
|
|
213
|
+
if not file_path:
|
|
214
|
+
return None
|
|
215
|
+
|
|
216
|
+
# Expand user path
|
|
217
|
+
file_path = Path(file_path).expanduser()
|
|
218
|
+
|
|
219
|
+
if file_path.exists():
|
|
220
|
+
# Validate dataset format
|
|
221
|
+
print(f"\033[96m⚡\033[0m Validating dataset...")
|
|
222
|
+
valid, message, processed_path = self.dataset_preparer.validate_dataset(file_path)
|
|
223
|
+
if valid:
|
|
224
|
+
print(f"\033[32m✓\033[0m {message}")
|
|
225
|
+
return processed_path
|
|
226
|
+
else:
|
|
227
|
+
print(f"\033[31m✗\033[0m {message}")
|
|
228
|
+
retry = input("\n\033[96m▶\033[0m Try another file? (\033[93my\033[0m/\033[2mn\033[0m): ").strip().lower()
|
|
229
|
+
if retry not in ['y', 'yes']:
|
|
230
|
+
return None
|
|
231
|
+
else:
|
|
232
|
+
print(f"\033[31m✗\033[0m File not found: {file_path}")
|
|
233
|
+
retry = input("\n\033[96m▶\033[0m Try another file? (\033[93my\033[0m/\033[2mn\033[0m): ").strip().lower()
|
|
234
|
+
if retry not in ['y', 'yes']:
|
|
235
|
+
return None
|
|
236
|
+
|
|
237
|
+
    def _create_interactive_dataset(self) -> Optional[Path]:
        """Create a dataset interactively.

        Collects prompt/response pairs from the user until they type "done"
        (or enter an empty prompt), then writes them as JSONL under
        ~/.cortex/temp_datasets/.

        Returns:
            Path to the written dataset, or None if no examples were given.
        """
        width = min(self.get_terminal_width() - 2, 70)

        print()
        self.cli.print_box_header("Interactive Dataset Creation", width)
        self.cli.print_empty_line(width)
        self.cli.print_box_line(" Enter prompt-response pairs.", width)
        self.cli.print_box_line(" Type '\033[93mdone\033[0m' when finished.", width)
        self.cli.print_box_line(" \033[2mMinimum 5 examples recommended.\033[0m", width)
        self.cli.print_empty_line(width)
        self.cli.print_box_footer(width)

        examples = []
        example_num = 1

        while True:
            print(f"\n\033[96mExample {example_num}:\033[0m")
            prompt = input(" \033[96m▶\033[0m Prompt: ").strip()

            if prompt.lower() == "done":
                # Below the recommended minimum: warn, and only stop if the
                # user explicitly confirms; otherwise keep collecting.
                if len(examples) < 5:
                    print(f"\033[93m⚠\033[0m You have {len(examples)} examples. Minimum recommended: 5")
                    cont = input("\033[96m▶\033[0m Continue anyway? (\033[2my\033[0m/\033[93mN\033[0m): ").strip().lower()
                    if cont != 'y':
                        continue
                break

            # An empty prompt also ends collection.
            if not prompt:
                break

            response = input(" \033[96m▶\033[0m Response: ").strip()
            if not response:
                # A prompt without a response is discarded; re-ask.
                print("\033[31m✗\033[0m Response required")
                continue

            examples.append({
                "prompt": prompt,
                "response": response
            })

            example_num += 1
            print("\033[32m✓\033[0m Added")

        if not examples:
            print("\033[31m✗\033[0m No examples provided")
            return None

        # Save to temporary file (overwrites any previous interactive dataset).
        dataset_path = Path.home() / ".cortex" / "temp_datasets" / "interactive_dataset.jsonl"
        dataset_path.parent.mkdir(parents=True, exist_ok=True)

        with open(dataset_path, 'w') as f:
            for example in examples:
                f.write(json.dumps(example) + '\n')

        print(f"\033[32m✓\033[0m Created dataset with {len(examples)} examples")
        return dataset_path
def _create_sample_dataset(self) -> Optional[Path]:
|
|
297
|
+
"""Create a sample dataset for testing."""
|
|
298
|
+
print("\n\033[96m⚡\033[0m Creating sample dataset...")
|
|
299
|
+
|
|
300
|
+
dataset_path = self.dataset_preparer.create_sample_dataset("general")
|
|
301
|
+
print(f"\033[32m✓\033[0m Sample dataset created (5 examples)")
|
|
302
|
+
|
|
303
|
+
return dataset_path
|
|
304
|
+
|
|
305
|
+
def _configure_training(self, base_model: str, dataset_path: Path) -> Optional[TrainingConfig]:
|
|
306
|
+
"""Configure training settings using intelligent presets."""
|
|
307
|
+
width = min(self.get_terminal_width() - 2, 70)
|
|
308
|
+
|
|
309
|
+
# Get model information for smart configuration
|
|
310
|
+
model_info = self._get_model_info(base_model)
|
|
311
|
+
model_size_gb = model_info.get('size_gb', 1.0) if model_info else 1.0
|
|
312
|
+
model_path = str(model_info.get('path', '')) if model_info else None
|
|
313
|
+
|
|
314
|
+
# Analyze model and dataset for smart defaults with accurate parameter detection
|
|
315
|
+
model_category, estimated_params = SmartConfigFactory.categorize_model_size(
|
|
316
|
+
model_size_gb, self.model_manager, model_path
|
|
317
|
+
)
|
|
318
|
+
dataset_info = SmartConfigFactory.analyze_dataset(dataset_path)
|
|
319
|
+
|
|
320
|
+
# Show intelligent configuration dialog
|
|
321
|
+
print()
|
|
322
|
+
self.cli.print_box_header("Smart Training Configuration", width)
|
|
323
|
+
self.cli.print_empty_line(width)
|
|
324
|
+
|
|
325
|
+
# Show detected characteristics
|
|
326
|
+
self.cli.print_box_line(f" \033[96mDetected:\033[0m", width)
|
|
327
|
+
self.cli.print_box_line(f" Model: \033[93m{model_category.title()}\033[0m ({estimated_params:.1f}B params, {model_size_gb:.1f}GB)", width)
|
|
328
|
+
self.cli.print_box_line(f" Dataset: \033[93m{dataset_info['size_category'].title()}\033[0m ({dataset_info['size']} examples)", width)
|
|
329
|
+
self.cli.print_box_line(f" Task type: \033[93m{dataset_info['task_type'].title()}\033[0m", width)
|
|
330
|
+
|
|
331
|
+
self.cli.print_empty_line(width)
|
|
332
|
+
self.cli.print_box_separator(width)
|
|
333
|
+
self.cli.print_empty_line(width)
|
|
334
|
+
|
|
335
|
+
self.cli.print_box_line(" \033[96mSelect training preset:\033[0m", width)
|
|
336
|
+
self.cli.print_empty_line(width)
|
|
337
|
+
|
|
338
|
+
# Get preset descriptions
|
|
339
|
+
presets = SmartConfigFactory.get_preset_configs()
|
|
340
|
+
|
|
341
|
+
self.cli.print_box_line(" \033[93m[1]\033[0m Quick \033[2m(fast experimentation)\033[0m", width)
|
|
342
|
+
self.cli.print_box_line(" \033[93m[2]\033[0m Balanced \033[2m(recommended for most cases)\033[0m", width)
|
|
343
|
+
self.cli.print_box_line(" \033[93m[3]\033[0m Quality \033[2m(best results, longer training)\033[0m", width)
|
|
344
|
+
self.cli.print_box_line(" \033[93m[4]\033[0m Expert \033[2m(full customization)\033[0m", width)
|
|
345
|
+
|
|
346
|
+
self.cli.print_empty_line(width)
|
|
347
|
+
self.cli.print_box_footer(width)
|
|
348
|
+
|
|
349
|
+
choice = self.cli.get_input_with_escape("Select preset (1-4)")
|
|
350
|
+
|
|
351
|
+
if choice is None:
|
|
352
|
+
return None
|
|
353
|
+
|
|
354
|
+
preset_map = {
|
|
355
|
+
"1": "quick",
|
|
356
|
+
"2": "balanced",
|
|
357
|
+
"3": "quality"
|
|
358
|
+
}
|
|
359
|
+
|
|
360
|
+
if choice in preset_map:
|
|
361
|
+
# Use smart configuration
|
|
362
|
+
preset = preset_map[choice]
|
|
363
|
+
config = SmartConfigFactory.create_smart_config(
|
|
364
|
+
model_size_gb=model_size_gb,
|
|
365
|
+
dataset_path=dataset_path,
|
|
366
|
+
preset=preset,
|
|
367
|
+
model_manager=self.model_manager,
|
|
368
|
+
model_path=model_path
|
|
369
|
+
)
|
|
370
|
+
|
|
371
|
+
# Show what the smart config decided
|
|
372
|
+
print(f"\n\033[96m⚡\033[0m Smart configuration applied:")
|
|
373
|
+
guidance = SmartConfigFactory.generate_guidance_message(config, base_model)
|
|
374
|
+
print(f" {guidance}")
|
|
375
|
+
|
|
376
|
+
elif choice == "4":
|
|
377
|
+
# Expert mode - full customization
|
|
378
|
+
config = self._expert_configuration(model_size_gb, dataset_path, model_category, model_path)
|
|
379
|
+
if not config:
|
|
380
|
+
return None
|
|
381
|
+
else:
|
|
382
|
+
print("\033[31m✗\033[0m Invalid selection")
|
|
383
|
+
return None
|
|
384
|
+
|
|
385
|
+
# Auto-adjust quantization based on model size
|
|
386
|
+
if model_size_gb > 30 and not config.quantization_bits:
|
|
387
|
+
config.quantization_bits = 4
|
|
388
|
+
print("\033[93m※\033[0m Auto-enabled 4-bit quantization for large model")
|
|
389
|
+
elif model_size_gb > 13 and not config.quantization_bits:
|
|
390
|
+
config.quantization_bits = 8
|
|
391
|
+
print("\033[93m※\033[0m Auto-enabled 8-bit quantization for medium model")
|
|
392
|
+
|
|
393
|
+
return config
|
|
394
|
+
|
|
395
|
+
    def _expert_configuration(self, model_size_gb: float, dataset_path: Path, model_category: str, model_path: Optional[str] = None) -> Optional[TrainingConfig]:
        """Expert mode configuration with full customization.

        Prompts for every tunable parameter; pressing Enter keeps the
        smart "balanced" default for that field.

        Returns:
            A validated TrainingConfig, or None on invalid input/cancel.
        """
        width = min(self.get_terminal_width() - 2, 70)

        print()
        self.cli.print_box_header("Expert Configuration", width)
        self.cli.print_empty_line(width)
        self.cli.print_box_line(" \033[96mConfigure advanced settings:\033[0m", width)
        self.cli.print_box_line(" \033[2mPress Enter to use smart defaults\033[0m", width)
        self.cli.print_empty_line(width)
        self.cli.print_box_footer(width)

        # Get smart defaults as starting point
        smart_config = SmartConfigFactory.create_smart_config(
            model_size_gb=model_size_gb,
            dataset_path=dataset_path,
            preset="balanced",
            model_manager=self.model_manager,
            model_path=model_path
        )

        try:
            # Core training parameters
            print("\n\033[96m━━━ Core Training Parameters ━━━\033[0m")
            epochs_str = input(f"\033[96m▶\033[0m Epochs \033[2m[{smart_config.epochs}]\033[0m: ").strip()
            epochs = int(epochs_str) if epochs_str else smart_config.epochs

            lr_str = input(f"\033[96m▶\033[0m Learning rate \033[2m[{smart_config.learning_rate:.1e}]\033[0m: ").strip()
            learning_rate = float(lr_str) if lr_str else smart_config.learning_rate

            batch_str = input(f"\033[96m▶\033[0m Batch size \033[2m[{smart_config.batch_size}]\033[0m: ").strip()
            batch_size = int(batch_str) if batch_str else smart_config.batch_size

            grad_acc_str = input(f"\033[96m▶\033[0m Gradient accumulation steps \033[2m[{smart_config.gradient_accumulation_steps}]\033[0m: ").strip()
            grad_acc_steps = int(grad_acc_str) if grad_acc_str else smart_config.gradient_accumulation_steps

            # LoRA parameters
            print("\n\033[96m━━━ LoRA Parameters ━━━\033[0m")
            lora_r_str = input(f"\033[96m▶\033[0m LoRA rank \033[2m[{smart_config.lora_r}]\033[0m: ").strip()
            lora_r = int(lora_r_str) if lora_r_str else smart_config.lora_r

            lora_alpha_str = input(f"\033[96m▶\033[0m LoRA alpha \033[2m[{smart_config.lora_alpha}]\033[0m: ").strip()
            lora_alpha = int(lora_alpha_str) if lora_alpha_str else smart_config.lora_alpha

            lora_dropout_str = input(f"\033[96m▶\033[0m LoRA dropout \033[2m[{smart_config.lora_dropout}]\033[0m: ").strip()
            lora_dropout = float(lora_dropout_str) if lora_dropout_str else smart_config.lora_dropout

            # Advanced options (optional)
            print("\n\033[96m━━━ Advanced Options (Optional) ━━━\033[0m")
            weight_decay_str = input(f"\033[96m▶\033[0m Weight decay \033[2m[{smart_config.weight_decay}]\033[0m: ").strip()
            weight_decay = float(weight_decay_str) if weight_decay_str else smart_config.weight_decay

            warmup_ratio_str = input(f"\033[96m▶\033[0m Warmup ratio \033[2m[{smart_config.warmup_ratio}]\033[0m: ").strip()
            warmup_ratio = float(warmup_ratio_str) if warmup_ratio_str else smart_config.warmup_ratio

            max_seq_len_str = input(f"\033[96m▶\033[0m Max sequence length \033[2m[{smart_config.max_sequence_length}]\033[0m: ").strip()
            max_seq_len = int(max_seq_len_str) if max_seq_len_str else smart_config.max_sequence_length

            # Create custom configuration; non-prompted metadata fields are
            # carried over from the smart defaults.
            config = TrainingConfig(
                epochs=epochs,
                learning_rate=learning_rate,
                batch_size=batch_size,
                gradient_accumulation_steps=grad_acc_steps,
                lora_r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                weight_decay=weight_decay,
                warmup_ratio=warmup_ratio,
                max_sequence_length=max_seq_len,
                task_type=smart_config.task_type,
                model_size_category=smart_config.model_size_category,
                estimated_parameters_b=smart_config.estimated_parameters_b,
                auto_configured=False,
                configuration_source="expert"
            )

            # Validate configuration before returning it to the caller.
            valid, message = config.validate()
            if not valid:
                print(f"\033[31m✗\033[0m Configuration error: {message}")
                return None

            print(f"\033[32m✓\033[0m Expert configuration created")
            return config

        except ValueError as e:
            # Raised by int()/float() on malformed input — abort, do not retry.
            print(f"\033[31m✗\033[0m Invalid value entered: {e}")
            return None
        except KeyboardInterrupt:
            print("\n\033[93m⚠\033[0m Configuration cancelled")
            return None
def _get_output_name(self, base_model: str) -> Optional[str]:
|
|
489
|
+
"""Get the output model name from user."""
|
|
490
|
+
width = min(self.get_terminal_width() - 2, 70)
|
|
491
|
+
|
|
492
|
+
default_name = f"{base_model}-finetuned"
|
|
493
|
+
|
|
494
|
+
# Show output name dialog
|
|
495
|
+
print()
|
|
496
|
+
self.cli.print_box_header("Output Model", width)
|
|
497
|
+
self.cli.print_empty_line(width)
|
|
498
|
+
self.cli.print_box_line(f" Enter name for fine-tuned model:", width)
|
|
499
|
+
self.cli.print_box_line(f" \033[2mDefault: {default_name}\033[0m", width)
|
|
500
|
+
self.cli.print_empty_line(width)
|
|
501
|
+
self.cli.print_box_footer(width)
|
|
502
|
+
|
|
503
|
+
name = input(f"\n\033[96m▶\033[0m Model name \033[2m[{default_name}]\033[0m: ").strip()
|
|
504
|
+
name = name if name else default_name
|
|
505
|
+
|
|
506
|
+
# Check if name already exists
|
|
507
|
+
existing_models = self._get_available_models()
|
|
508
|
+
if any(model_name == name for model_name, _ in existing_models):
|
|
509
|
+
choice = input(f"\n\033[93m⚠\033[0m Model '{name}' exists. Overwrite? (\033[2my\033[0m/\033[93mN\033[0m): ").strip().lower()
|
|
510
|
+
if choice != 'y':
|
|
511
|
+
return None
|
|
512
|
+
|
|
513
|
+
return name
|
|
514
|
+
|
|
515
|
+
def _confirm_settings(self, base_model: str, dataset_path: Path,
|
|
516
|
+
config: TrainingConfig, output_name: str) -> bool:
|
|
517
|
+
"""Show summary and confirm settings."""
|
|
518
|
+
width = min(self.get_terminal_width() - 2, 70)
|
|
519
|
+
|
|
520
|
+
# Count dataset examples
|
|
521
|
+
example_count = sum(1 for _ in open(dataset_path))
|
|
522
|
+
|
|
523
|
+
# Estimate training time
|
|
524
|
+
estimated_time = self._estimate_training_time(example_count, config)
|
|
525
|
+
|
|
526
|
+
# Show summary dialog
|
|
527
|
+
print()
|
|
528
|
+
self.cli.print_box_header("Training Summary", width)
|
|
529
|
+
self.cli.print_empty_line(width)
|
|
530
|
+
|
|
531
|
+
self.cli.print_box_line(" \033[96mConfiguration:\033[0m", width)
|
|
532
|
+
self.cli.print_empty_line(width)
|
|
533
|
+
|
|
534
|
+
self.cli.print_box_line(f" Base model: \033[93m{base_model}\033[0m", width)
|
|
535
|
+
self.cli.print_box_line(f" Output model: \033[93m{output_name}\033[0m", width)
|
|
536
|
+
self.cli.print_box_line(f" Dataset: {dataset_path.name} \033[2m({example_count} examples)\033[0m", width)
|
|
537
|
+
|
|
538
|
+
self.cli.print_empty_line(width)
|
|
539
|
+
|
|
540
|
+
self.cli.print_box_line(f" Model size: \033[93m{config.model_size_category.title()}\033[0m ({config.estimated_parameters_b:.1f}B params)", width)
|
|
541
|
+
self.cli.print_box_line(f" Task type: {config.task_type.title()}", width)
|
|
542
|
+
self.cli.print_box_line(f" Config source: {config.configuration_source.replace('_', ' ').title()}", width)
|
|
543
|
+
|
|
544
|
+
self.cli.print_empty_line(width)
|
|
545
|
+
|
|
546
|
+
self.cli.print_box_line(f" Epochs: {config.epochs}", width)
|
|
547
|
+
self.cli.print_box_line(f" Learning rate: {config.learning_rate:.1e}", width)
|
|
548
|
+
self.cli.print_box_line(f" LoRA rank: {config.lora_r}", width)
|
|
549
|
+
self.cli.print_box_line(f" Batch size: {config.batch_size} (x{config.gradient_accumulation_steps} acc.)", width)
|
|
550
|
+
if config.quantization_bits:
|
|
551
|
+
self.cli.print_box_line(f" Quantization: {config.quantization_bits}-bit", width)
|
|
552
|
+
|
|
553
|
+
self.cli.print_empty_line(width)
|
|
554
|
+
self.cli.print_box_line(f" \033[2mEstimated time: {estimated_time}\033[0m", width)
|
|
555
|
+
|
|
556
|
+
self.cli.print_empty_line(width)
|
|
557
|
+
self.cli.print_box_separator(width)
|
|
558
|
+
self.cli.print_empty_line(width)
|
|
559
|
+
|
|
560
|
+
self.cli.print_box_line(" Start fine-tuning?", width)
|
|
561
|
+
self.cli.print_empty_line(width)
|
|
562
|
+
self.cli.print_box_line(" \033[93m[Y]\033[0m Yes, start training", width)
|
|
563
|
+
self.cli.print_box_line(" \033[93m[N]\033[0m No, cancel", width)
|
|
564
|
+
|
|
565
|
+
self.cli.print_empty_line(width)
|
|
566
|
+
self.cli.print_box_footer(width)
|
|
567
|
+
|
|
568
|
+
choice = input("\n\033[96m▶\033[0m Choice (\033[93my\033[0m/\033[2mn\033[0m): ").strip().lower()
|
|
569
|
+
return choice in ['y', 'yes', '']
|
|
570
|
+
|
|
571
|
+
    def _run_training(self, base_model: str, dataset_path: Path,
                      config: TrainingConfig, output_name: str) -> bool:
        """Run the actual training.

        Instantiates the MLX LoRA trainer, drives it with a throttled
        progress-bar callback, and reports where the adapter was saved.

        Returns:
            True on successful training, False on failure/interrupt.
        """
        print("\n\033[96m⚡\033[0m Starting fine-tuning...")

        try:
            # Hard requirement: MLX must be available for fine-tuning.
            if not MLXLoRATrainer.is_available():
                print("\n\033[31m✗\033[0m Fine-tuning requires MLX/Metal, but MLX is not available in this environment.")
                return False
            # Use MLXLoRATrainer for proper LoRA implementation
            self.trainer = MLXLoRATrainer(self.model_manager, self.config)

            # Progress tracking
            start_time = time.time()
            last_update = start_time

            def update_progress(epoch, step, loss):
                # Throttled console progress bar; redraws at most every 0.5 s.
                nonlocal last_update
                current_time = time.time()

                # Update every 0.5 seconds
                if current_time - last_update > 0.5:
                    elapsed = current_time - start_time  # NOTE(review): currently unused
                    # NOTE(review): assumes ~100 steps per epoch when mapping
                    # (epoch, step) onto overall progress — confirm with trainer.
                    progress = ((epoch * 100) + min(step, 99)) / (config.epochs * 100)

                    # Create progress bar
                    bar_width = 30
                    filled = int(bar_width * progress)
                    bar = "█" * filled + "░" * (bar_width - filled)

                    # Print progress
                    sys.stdout.write(f"\r {bar} {progress*100:.0f}% | Epoch {epoch+1}/{config.epochs} | Loss: {loss:.4f}")
                    sys.stdout.flush()
                    last_update = current_time

            # Run training
            success = self.trainer.train(
                base_model_name=base_model,
                dataset_path=dataset_path,
                output_name=output_name,
                training_config=config,
                progress_callback=update_progress
            )

            print()  # New line after progress

            if success:
                print(f"\n\033[32m✓\033[0m Fine-tuning completed!")

                # Show where the model was saved
                mlx_path = Path.home() / ".cortex" / "mlx_models" / output_name
                if mlx_path.exists():
                    print(f"\n\033[96m📍\033[0m Model saved to: \033[93m{mlx_path}\033[0m")
                    print(f"\n\033[96m💡\033[0m To load your fine-tuned model:")
                    print(f" \033[93m/model {mlx_path}\033[0m")

                    # Check if adapter weights exist
                    adapter_file = mlx_path / "adapter.safetensors"
                    if adapter_file.exists():
                        size_mb = adapter_file.stat().st_size / (1024 * 1024)
                        print(f"\n\033[2m LoRA adapter size: {size_mb:.1f} MB\033[0m")
                        print(f"\033[2m Base model: {base_model}\033[0m")

                return True
            else:
                print("\n\033[31m✗\033[0m Fine-tuning failed")
                return False

        except KeyboardInterrupt:
            print("\n\n\033[93m⚠\033[0m Training interrupted by user")
            return False
        except Exception as e:
            # Boundary handler: log, show traceback, report failure to caller.
            logger.error(f"Training failed: {e}")
            print(f"\n\n\033[31m✗\033[0m Training error: {e}")
            import traceback
            traceback.print_exc()
            return False
def _get_available_models(self) -> List[Tuple[str, Dict[str, Any]]]:
|
|
651
|
+
"""Get list of available models."""
|
|
652
|
+
models = []
|
|
653
|
+
|
|
654
|
+
# Get models from model manager
|
|
655
|
+
discovered = self.model_manager.discover_available_models()
|
|
656
|
+
|
|
657
|
+
for model_info in discovered:
|
|
658
|
+
try:
|
|
659
|
+
name = model_info.get('name', 'Unknown')
|
|
660
|
+
info = {
|
|
661
|
+
'path': Path(model_info.get('path', '')),
|
|
662
|
+
'format': model_info.get('format', 'Unknown'),
|
|
663
|
+
'size_gb': model_info.get('size_gb', 0.0)
|
|
664
|
+
}
|
|
665
|
+
models.append((name, info))
|
|
666
|
+
except Exception as e:
|
|
667
|
+
logger.debug(f"Error processing model info: {e}")
|
|
668
|
+
continue
|
|
669
|
+
|
|
670
|
+
return sorted(models, key=lambda x: x[0])
|
|
671
|
+
|
|
672
|
+
def _get_model_info(self, model_name: str) -> Optional[Dict[str, Any]]:
|
|
673
|
+
"""Get information about a model."""
|
|
674
|
+
models = self._get_available_models()
|
|
675
|
+
for name, info in models:
|
|
676
|
+
if name == model_name:
|
|
677
|
+
return info
|
|
678
|
+
return None
|
|
679
|
+
|
|
680
|
+
def _estimate_training_time(self, example_count: int, config: TrainingConfig) -> str:
|
|
681
|
+
"""Estimate training time based on dataset size, epochs, and model characteristics."""
|
|
682
|
+
# Base time estimation adjusted for model size and batch settings
|
|
683
|
+
base_seconds_per_example = {
|
|
684
|
+
"tiny": 0.1, # Very fast for small models
|
|
685
|
+
"small": 0.3, # Fast
|
|
686
|
+
"medium": 0.7, # Standard
|
|
687
|
+
"large": 1.5, # Slower for large models
|
|
688
|
+
"xlarge": 3.0 # Much slower
|
|
689
|
+
}.get(config.model_size_category, 0.7)
|
|
690
|
+
|
|
691
|
+
# Adjust for gradient accumulation (more accumulation = fewer actual updates)
|
|
692
|
+
effective_batch_size = config.batch_size * config.gradient_accumulation_steps
|
|
693
|
+
batch_factor = max(0.5, 1.0 / (effective_batch_size ** 0.5)) # Larger batches are more efficient
|
|
694
|
+
|
|
695
|
+
# Adjust for quantization (if enabled, training is faster)
|
|
696
|
+
quant_factor = 0.7 if config.quantization_bits else 1.0
|
|
697
|
+
|
|
698
|
+
# Calculate total time
|
|
699
|
+
adjusted_time_per_example = base_seconds_per_example * batch_factor * quant_factor
|
|
700
|
+
total_seconds = example_count * config.epochs * adjusted_time_per_example
|
|
701
|
+
|
|
702
|
+
if total_seconds < 60:
|
|
703
|
+
return f"~{int(total_seconds)} seconds"
|
|
704
|
+
elif total_seconds < 3600:
|
|
705
|
+
return f"~{int(total_seconds / 60)} minutes"
|
|
706
|
+
else:
|
|
707
|
+
return f"~{total_seconds / 3600:.1f} hours"
|