cortex-llm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortex/__init__.py +73 -0
- cortex/__main__.py +83 -0
- cortex/config.py +329 -0
- cortex/conversation_manager.py +468 -0
- cortex/fine_tuning/__init__.py +8 -0
- cortex/fine_tuning/dataset.py +332 -0
- cortex/fine_tuning/mlx_lora_trainer.py +502 -0
- cortex/fine_tuning/trainer.py +957 -0
- cortex/fine_tuning/wizard.py +707 -0
- cortex/gpu_validator.py +467 -0
- cortex/inference_engine.py +727 -0
- cortex/metal/__init__.py +275 -0
- cortex/metal/gpu_validator.py +177 -0
- cortex/metal/memory_pool.py +886 -0
- cortex/metal/mlx_accelerator.py +678 -0
- cortex/metal/mlx_converter.py +638 -0
- cortex/metal/mps_optimizer.py +417 -0
- cortex/metal/optimizer.py +665 -0
- cortex/metal/performance_profiler.py +364 -0
- cortex/model_downloader.py +130 -0
- cortex/model_manager.py +2187 -0
- cortex/quantization/__init__.py +5 -0
- cortex/quantization/dynamic_quantizer.py +736 -0
- cortex/template_registry/__init__.py +15 -0
- cortex/template_registry/auto_detector.py +144 -0
- cortex/template_registry/config_manager.py +234 -0
- cortex/template_registry/interactive.py +260 -0
- cortex/template_registry/registry.py +347 -0
- cortex/template_registry/template_profiles/__init__.py +5 -0
- cortex/template_registry/template_profiles/base.py +142 -0
- cortex/template_registry/template_profiles/complex/__init__.py +5 -0
- cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
- cortex/template_registry/template_profiles/standard/__init__.py +9 -0
- cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
- cortex/template_registry/template_profiles/standard/chatml.py +82 -0
- cortex/template_registry/template_profiles/standard/gemma.py +103 -0
- cortex/template_registry/template_profiles/standard/llama.py +87 -0
- cortex/template_registry/template_profiles/standard/simple.py +65 -0
- cortex/ui/__init__.py +120 -0
- cortex/ui/cli.py +1685 -0
- cortex/ui/markdown_render.py +185 -0
- cortex/ui/terminal_app.py +534 -0
- cortex_llm-1.0.0.dist-info/METADATA +275 -0
- cortex_llm-1.0.0.dist-info/RECORD +48 -0
- cortex_llm-1.0.0.dist-info/WHEEL +5 -0
- cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
- cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
- cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,502 @@
|
|
|
1
|
+
"""MLX LoRA trainer using mlx_lm implementation."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import json
|
|
5
|
+
import time
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Optional, Dict, Any, Callable
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
import shutil
|
|
10
|
+
import math
|
|
11
|
+
|
|
12
|
+
# MLX (and with it Metal) is optional at import time: keep this module
# importable without it so callers can query MLXLoRATrainer.is_available()
# and show a clear message instead of crashing on import.
_MLX_IMPORT_ERROR: Optional[BaseException] = None  # set only when the imports below fail

try:
    import mlx.core as mx
    import mlx.nn as nn
    from mlx_lm import load as mlx_load
    from mlx_lm.tuner.utils import linear_to_lora_layers
    from mlx_lm.tuner.datasets import load_dataset as mlx_load_dataset, CacheDataset
    from mlx_lm.tuner.trainer import TrainingArgs, train, evaluate, TrainingCallback
    import mlx.optimizers as optim

    MLX_AVAILABLE = True
except Exception as exc:  # noqa: BLE001
    # Any failure here (package missing, MLX/Metal unsupported) disables
    # fine-tuning; placeholders keep the module-level names defined.
    MLX_AVAILABLE = False
    mx = nn = mlx_load = linear_to_lora_layers = mlx_load_dataset = CacheDataset = TrainingArgs = train = evaluate = TrainingCallback = optim = None  # type: ignore
    _MLX_IMPORT_ERROR = exc

logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class LoRAConfig:
    """Hyper-parameters for LoRA fine-tuning.

    Attributes:
        rank: Rank of the low-rank update matrices; must be positive.
        alpha: LoRA alpha; conventionally ``alpha / rank`` is used as the
            scale applied to the adapter output.
        dropout: Dropout probability on the adapter path, in ``[0, 1)``.
        target_modules: Names of the modules to adapt. ``None`` selects the
            attention projections; any given sequence is copied so the config
            never aliases a caller's list.

    Raises:
        ValueError: If ``rank`` is not positive or ``dropout`` is outside
            ``[0, 1)``.
    """

    rank: int = 8
    alpha: float = 16.0
    dropout: float = 0.0
    target_modules: Optional[list] = None

    def __post_init__(self):
        # Fail early with a clear message instead of a ZeroDivisionError when
        # the scale (alpha / rank) is computed, or a dropout-layer error later.
        if self.rank <= 0:
            raise ValueError(f"LoRA rank must be positive, got {self.rank}")
        if not 0.0 <= self.dropout < 1.0:
            raise ValueError(f"LoRA dropout must be in [0, 1), got {self.dropout}")

        # None still means "use the defaults" so callers that forward an
        # unset option explicitly keep working.
        if self.target_modules is None:
            self.target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
        elif isinstance(self.target_modules, str):
            self.target_modules = [self.target_modules]
        else:
            self.target_modules = list(self.target_modules)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class MLXLoRATrainer:
    """LoRA trainer built on ``mlx_lm``'s tuner utilities.

    ``train`` works in five steps:

    1. Resolve the base model on disk.
    2. Reuse an already-loaded copy from the model manager when one matches;
       otherwise load it with ``mlx_lm.load``.
    3. Freeze the model and wrap its linear layers with LoRA adapters.
    4. Train the adapters with ``mlx_lm.tuner.trainer.train``.
    5. Write a new model directory holding the base-model files plus the
       trained adapter.
    """

    def __init__(self, model_manager, config):
        """Store collaborators.

        Args:
            model_manager: Optional model manager. When set, the trainer
                consults its ``model_cache``/``tokenizers`` mappings to reuse
                an already-loaded model. Its ``config.model.model_path`` is an
                extra search location for base models.
            config: Application configuration; stored but not read here.
        """
        self.model_manager = model_manager
        self.config = config

    @staticmethod
    def is_available() -> bool:
        """Return True when the MLX/Metal stack was importable."""
        return MLX_AVAILABLE

    def train(
        self,
        base_model_name: str,
        dataset_path: Path,
        output_name: str,
        training_config: Any,
        progress_callback: Optional[Callable] = None
    ) -> bool:
        """
        Train a model using LoRA with mlx_lm.

        Args:
            base_model_name: Name (or path) of the base model to fine-tune.
            dataset_path: Path to the JSONL training dataset.
            output_name: Name for the fine-tuned model. The model is written to
                ``~/.cortex/mlx_models/<output_name>`` and the raw adapter to
                ``~/.cortex/adapters/<output_name>``.
            training_config: Object exposing ``lora_r``, ``lora_alpha``,
                ``lora_dropout``, ``batch_size``, ``epochs`` and
                ``learning_rate``. Optional attributes: ``target_modules``,
                ``num_lora_layers`` (default 16), ``gradient_checkpointing``,
                ``weight_decay`` (default 0.01), ``gradient_accumulation_steps``
                and ``max_sequence_length``.
            progress_callback: Optional ``callback(epoch, step, loss)`` invoked
                on each mlx_lm training-loss report.

        Returns:
            True if training and saving succeeded, False otherwise. Errors are
            logged and printed, not raised.
        """
        try:
            if not MLX_AVAILABLE:
                logger.error("MLX is not available; fine-tuning requires MLX.")
                import_error = globals().get("_MLX_IMPORT_ERROR")
                if import_error is not None:
                    logger.debug(f"MLX import error: {import_error}")
                return False

            if not output_name:
                # An empty name would make the output directory the shared
                # ~/.cortex/mlx_models root.
                logger.error("Output name for the fine-tuned model must not be empty")
                return False

            logger.info(f"Starting MLX LoRA training: {base_model_name} -> {output_name}")

            model_path = self._get_model_path(base_model_name)
            if not model_path:
                logger.error(f"Could not find model path for {base_model_name}")
                return False

            # Reuse an already loaded copy to avoid a second full load in unified memory.
            # NOTE(review): a cached model is the object the model manager hands out
            # for inference. The LoRA conversion and training below mutate it in place.
            model, tokenizer = self._find_cached_model(base_model_name, model_path)
            if model is None or tokenizer is None:
                logger.info(f"Loading model from {model_path}")
                model, tokenizer = mlx_load(str(model_path))

            logger.info(f"Applying LoRA with rank={training_config.lora_r}")
            lora_config = LoRAConfig(
                rank=training_config.lora_r,
                alpha=training_config.lora_alpha,
                dropout=training_config.lora_dropout,
                target_modules=getattr(training_config, "target_modules", None),
            )
            # NOTE(review): lora_config.target_modules is not forwarded. mlx_lm's
            # optional "keys" entry appears to expect module paths such as
            # "self_attn.q_proj", so mlx_lm's own default key set applies.
            # Confirm before wiring it through.

            # linear_to_lora_layers only swaps modules in place (returns None) and
            # freezes nothing, so freeze the base weights first. Otherwise the
            # optimizer would update every parameter, not just the adapters.
            model.freeze()
            linear_to_lora_layers(
                model,
                num_layers=getattr(training_config, "num_lora_layers", 16),
                config={
                    "rank": lora_config.rank,
                    "dropout": lora_config.dropout,
                    "scale": lora_config.alpha / lora_config.rank,
                },
            )

            logger.info(f"Loading dataset from {dataset_path}")
            train_data = self._load_dataset(dataset_path, tokenizer, training_config)
            if not train_data:
                logger.error("Failed to load dataset")
                return False

            dataset_len = self._dataset_length(train_data)
            length_known = dataset_len is not None
            if not length_known:
                dataset_len = 1  # Fallback
            logger.info(f"Dataset contains {dataset_len} examples")

            # mlx_lm refuses datasets smaller than one batch, so clamp the batch
            # size for small datasets.
            batch_size = max(1, training_config.batch_size)
            if length_known and batch_size > dataset_len:
                logger.warning(
                    f"Batch size {batch_size} exceeds dataset size {dataset_len}; "
                    f"using batch size {dataset_len}"
                )
                batch_size = dataset_len

            adapter_dir = Path.home() / ".cortex" / "adapters" / output_name
            adapter_dir.mkdir(parents=True, exist_ok=True)
            adapter_file = str(adapter_dir / "adapter.safetensors")

            # No gradient-accumulation setting reaches TrainingArgs, so every
            # iteration consumes exactly one batch. Size the run on batch_size
            # alone; otherwise fewer epochs than requested would be trained.
            accumulation = getattr(training_config, "gradient_accumulation_steps", 1)
            if accumulation and accumulation > 1:
                logger.info(
                    f"gradient_accumulation_steps={accumulation} is not applied by the MLX trainer"
                )
            num_iters = max(1, math.ceil((dataset_len * training_config.epochs) / batch_size))

            training_kwargs: Dict[str, Any] = {
                "batch_size": batch_size,
                "iters": num_iters,
                "steps_per_report": 10,
                # Push periodic evaluation past the last iteration (mlx_lm may
                # still evaluate at the first/last step), limited to one batch.
                "steps_per_eval": num_iters + 1,
                "val_batches": 1,
                "steps_per_save": 100,
                "adapter_file": adapter_file,
                "grad_checkpoint": getattr(training_config, "gradient_checkpointing", False),
            }
            max_seq_len = getattr(training_config, "max_sequence_length", None)
            if max_seq_len:
                # Token-level sequence cap used by mlx_lm while batching;
                # without it mlx_lm's own default applies.
                training_kwargs["max_seq_length"] = int(max_seq_len)
            training_args = TrainingArgs(**training_kwargs)

            optimizer = optim.AdamW(
                learning_rate=training_config.learning_rate,
                weight_decay=getattr(training_config, "weight_decay", 0.01),
            )

            tracker = (
                self._make_progress_tracker(
                    progress_callback, dataset_len, batch_size, training_config.epochs
                )
                if progress_callback
                else None
            )

            # mlx_lm's train() needs a validation dataset (it can't be None);
            # reuse the training data.
            val_data = train_data
            logger.info("Using training data for validation (small dataset)")

            logger.info("Starting training...")
            # train() returns nothing: it updates the model in place and writes
            # the adapter weights to training_args.adapter_file.
            train(
                model,
                optimizer,
                train_dataset=train_data,
                val_dataset=val_data,
                args=training_args,
                training_callback=tracker,
            )

            logger.info(f"Saving fine-tuned model to {output_name}")
            success = self._save_model(
                model=model,
                tokenizer=tokenizer,
                output_name=output_name,
                base_model_name=base_model_name,
                adapter_path=str(adapter_dir),
            )

            if success:
                logger.info(f"Successfully saved fine-tuned model as {output_name}")
                # Intermediate checkpoints are only useful for crash recovery.
                self._cleanup_checkpoints(adapter_dir)
                return True
            logger.error("Failed to save fine-tuned model")
            return False

        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
            print("\n\033[93m⚠\033[0m Training interrupted by user")
            return False
        except Exception as e:
            logger.error(f"Training failed: {e}")
            print(f"\n\033[31m✗\033[0m Training error: {str(e)}")
            import traceback
            traceback.print_exc()
            return False

    def _find_cached_model(self, base_model_name: str, model_path: Path):
        """Return ``(model, tokenizer)`` already held by the model manager, or ``(None, None)``.

        Model and tokenizer must come from the same cache key, so a model is
        never paired with another model's tokenizer. The manager's
        ``current_model`` key is considered only when it refers to the
        requested base model. Otherwise a different loaded model would be
        fine-tuned under the base model's name.
        """
        if not self.model_manager:
            return None, None

        model_cache = getattr(self.model_manager, "model_cache", None) or {}
        tokenizers = getattr(self.model_manager, "tokenizers", None) or {}

        cache_keys = [str(model_path), base_model_name]
        current_model = getattr(self.model_manager, "current_model", None)
        if current_model and self._refers_to_model(current_model, base_model_name, model_path):
            cache_keys.append(current_model)

        for key in cache_keys:
            if not key:
                continue
            model = model_cache.get(key)
            tokenizer = tokenizers.get(key)
            if model is not None and tokenizer is not None:
                logger.info(f"Reusing loaded model from cache (key: {key})")
                return model, tokenizer
        return None, None

    @staticmethod
    def _refers_to_model(key: Any, base_model_name: str, model_path: Any) -> bool:
        """Return True if cache ``key`` names the same model as ``base_model_name``/``model_path``."""
        key = str(key)
        if key in (base_model_name, str(model_path)):
            return True
        try:
            return Path(key).expanduser().resolve() == Path(model_path).expanduser().resolve()
        except (OSError, RuntimeError, ValueError):
            return False

    @staticmethod
    def _dataset_length(dataset: Any) -> Optional[int]:
        """Return the number of examples in ``dataset``, or None if it cannot be determined."""
        if hasattr(dataset, "__len__"):
            return len(dataset)
        inner = getattr(dataset, "data", None)
        if inner is not None and hasattr(inner, "__len__"):
            return len(inner)
        return None

    @staticmethod
    def _make_progress_tracker(callback: Callable, dataset_len: int, batch_size: int, epochs: Any):
        """Build an mlx_lm ``TrainingCallback`` that reports ``callback(epoch, step, loss)``.

        The subclass is created here rather than at module level because
        ``TrainingCallback`` is a ``None`` placeholder when MLX failed to import.
        """
        steps_per_epoch = max(1, dataset_len // batch_size)
        last_epoch = max(1, epochs) - 1

        class ProgressTracker(TrainingCallback):
            def on_train_loss_report(self, train_info: dict):
                """Translate an mlx_lm loss report into (epoch, step, loss)."""
                # MLX iterations start at 1, not 0.
                actual_iter = max(0, train_info.get("iteration", 0) - 1)
                loss = train_info.get("train_loss", 0.0)
                epoch = min(actual_iter // steps_per_epoch, last_epoch)
                step = actual_iter % steps_per_epoch
                callback(epoch, step, loss)

        return ProgressTracker()

    def _get_model_path(self, model_name: str) -> Optional[Path]:
        """Resolve ``model_name`` to a model location on disk.

        Search order:

        1. ``~/.cortex/mlx_models/<name>`` (converted or fine-tuned MLX models).
        2. ``~/.cortex/models/<name>``.
        3. The model manager's configured ``config.model.model_path``.
        4. ``model_name`` itself as a path, with ``~`` expanded; relative paths
           resolve against the current directory.

        Returns:
            The first existing candidate, or None if none exists or
            ``model_name`` is empty.
        """
        if not model_name:
            # Every "<base> / model_name" join would collapse to <base> itself
            # (and Path("") is the cwd), producing a bogus match.
            return None

        cortex_home = Path.home() / ".cortex"

        mlx_path = cortex_home / "mlx_models" / model_name
        if mlx_path.exists():
            logger.info(f"Found MLX model at: {mlx_path}")
            return mlx_path

        models_path = cortex_home / "models" / model_name
        if models_path.exists():
            logger.info(f"Found model at: {models_path}")
            return models_path

        app_config = getattr(self.model_manager, "config", None) if self.model_manager else None
        if app_config:
            try:
                config_model_path = Path(app_config.model.model_path).expanduser().resolve()
                config_path = config_model_path / model_name
                if config_path.exists():
                    logger.info(f"Found model in configured path: {config_path}")
                    return config_path
            except Exception as e:
                logger.debug(f"Could not check configured model path: {e}")

        # Treat the name as a path. The OS already resolves relative paths
        # against the cwd, so no separate cwd-relative lookup is needed.
        try:
            direct_path = Path(model_name).expanduser()
        except RuntimeError:  # "~user" whose home directory cannot be determined
            direct_path = Path(model_name)
        if direct_path.exists():
            full_path = direct_path.resolve()
            logger.info(f"Found model at full path: {full_path}")
            return full_path

        logger.error(f"Model not found: {model_name}")
        return None

    @staticmethod
    def _read_jsonl(dataset_path: Any) -> list:
        """Parse a JSONL file into a list of records, skipping blank lines.

        Raises:
            FileNotFoundError: If the file does not exist.
            json.JSONDecodeError: On an invalid line; the message names the
                1-based line number in the dataset.
        """
        records = []
        # utf-8-sig also tolerates a leading byte-order mark.
        with open(dataset_path, "r", encoding="utf-8-sig") as handle:
            for lineno, line in enumerate(handle, start=1):
                stripped = line.strip()
                if not stripped:
                    continue  # e.g. a trailing empty line
                try:
                    records.append(json.loads(stripped))
                except json.JSONDecodeError as err:
                    raise json.JSONDecodeError(
                        f"{err.msg} (dataset line {lineno})", err.doc, err.pos
                    ) from err
        return records

    @staticmethod
    def _example_to_text(example: Any) -> Optional[str]:
        """Flatten one dataset record into a single training string.

        Supported record shapes:

        - ``prompt``/``response`` → ``"User: …\\n\\nAssistant: …"``
        - ``prompt``/``completion`` → same layout
        - ``text`` → used verbatim

        Returns None for records of any other shape, including non-dict records.
        """
        if not isinstance(example, dict):
            return None
        if "prompt" in example and "response" in example:
            return f"User: {example['prompt']}\n\nAssistant: {example['response']}"
        if "prompt" in example and "completion" in example:
            return f"User: {example['prompt']}\n\nAssistant: {example['completion']}"
        if "text" in example:
            text = example["text"]
            return text if isinstance(text, str) else None
        return None

    def _load_dataset(self, dataset_path: Path, tokenizer: Any, training_config: Any) -> Optional[Any]:
        """Load a JSONL dataset as a cached mlx_lm ``TextDataset``.

        Every record is flattened to a single ``text`` field by
        ``_example_to_text``, which avoids depending on tokenizer chat
        templates.

        Returns:
            The wrapped dataset. Returns None (after logging) on any error, or
            if no usable record remains.
        """
        try:
            from mlx_lm.tuner.datasets import CacheDataset, TextDataset

            examples = self._read_jsonl(dataset_path)
            if not examples:
                logger.error("No examples found in dataset")
                return None

            max_seq_len = getattr(training_config, "max_sequence_length", None)
            # Crude char-level pre-trim (~4 chars per token) to avoid tokenizing
            # huge records. The token-level cap is TrainingArgs.max_seq_length.
            max_chars = max_seq_len * 4 if max_seq_len else None

            text_examples = []
            for example in examples:
                text = self._example_to_text(example)
                if text is None:
                    logger.warning(f"Skipping example with unsupported format: {example}")
                    continue
                if max_chars and len(text) > max_chars:
                    text = text[:max_chars]
                text_examples.append({"text": text})

            if not text_examples:
                logger.error("No valid examples found after conversion")
                return None

            # TextDataset tokenizes the "text" field; CacheDataset memoizes that work.
            dataset = TextDataset(data=text_examples, tokenizer=tokenizer, text_key="text")
            cached_dataset = CacheDataset(dataset)

            logger.info(f"Loaded {len(text_examples)} training examples")
            return cached_dataset

        except ImportError as e:
            logger.error(f"Required dataset classes not available: {e}")
            return None
        except FileNotFoundError:
            logger.error(f"Dataset file not found: {dataset_path}")
            return None
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in dataset: {e}")
            return None
        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _save_model(
        self,
        model: Any,
        tokenizer: Any,
        output_name: str,
        base_model_name: str,
        adapter_path: str
    ) -> bool:
        """Write the fine-tuned model directory: base-model files plus the LoRA adapter.

        No weights are fused. The method:

        - copies the base model's files and sub-directories into
          ``~/.cortex/mlx_models/<output_name>``;
        - copies the final adapter ``*.safetensors`` (and
          ``adapter_config.json`` if present) next to them, skipping numbered
          checkpoints;
        - adds fine-tuning metadata to ``config.json``;
        - writes a ``fine_tuned.marker`` file.

        Args:
            model: Unused; kept for interface compatibility.
            tokenizer: Unused; kept for interface compatibility.
            output_name: Name of the output model directory.
            base_model_name: Base model to copy, resolved via ``_get_model_path``.
            adapter_path: Directory containing the trained adapter.

        Returns:
            True on success. False on any error; the error is logged and the
            traceback printed.
        """
        try:
            if not output_name:
                logger.error("Output name for the fine-tuned model must not be empty")
                return False

            # Always save to the MLX models directory for consistent loading.
            output_dir = Path.home() / ".cortex" / "mlx_models" / output_name

            base_model_path = self._get_model_path(base_model_name)
            if not base_model_path or not base_model_path.exists():
                logger.error(f"Base model path not found: {base_model_name}")
                return False

            # Copying the base model onto itself would fail partway through
            # (shutil.SameFileError) and could edit the base config.
            if output_dir.resolve() == base_model_path.resolve():
                logger.error(
                    f"Output name '{output_name}' resolves to the base model directory "
                    f"{base_model_path}; choose a different name"
                )
                return False

            output_dir.mkdir(parents=True, exist_ok=True)
            logger.info(f"Saving fine-tuned model to {output_dir}")

            logger.info(f"Copying base model files from {base_model_path} to {output_dir}")
            for entry in base_model_path.glob("*"):
                if entry.is_file():
                    shutil.copy2(entry, output_dir / entry.name)
                elif entry.is_dir():
                    shutil.copytree(entry, output_dir / entry.name, dirs_exist_ok=True)

            # Copy only the final adapter, not intermediate checkpoints.
            adapter_dir = Path(adapter_path)
            if adapter_dir.exists():
                for adapter_file in adapter_dir.glob("*.safetensors"):
                    # Checkpoints look like 0000100_adapters.safetensors.
                    if adapter_file.name.endswith('_adapters.safetensors'):
                        logger.debug(f"Skipping checkpoint: {adapter_file.name}")
                        continue
                    logger.info(f"Copying adapter: {adapter_file.name}")
                    shutil.copy2(adapter_file, output_dir / adapter_file.name)

                adapter_config = adapter_dir / "adapter_config.json"
                if adapter_config.exists():
                    shutil.copy2(adapter_config, output_dir / "adapter_config.json")

            # One timestamp so config.json and the marker agree.
            created_at = time.strftime("%Y-%m-%d %H:%M:%S")

            config_path = output_dir / "config.json"
            if config_path.exists():
                with open(config_path, 'r', encoding="utf-8") as f:
                    config = json.load(f)

                config["fine_tuned"] = True
                config["base_model"] = base_model_name
                config["fine_tuning_method"] = "LoRA"
                config["lora_adapter"] = True
                config["created_at"] = created_at

                with open(config_path, 'w', encoding="utf-8") as f:
                    json.dump(config, f, indent=2)

            # Marker file used to detect fine-tuned models.
            with open(output_dir / "fine_tuned.marker", 'w', encoding="utf-8") as f:
                f.write(f"LoRA fine-tuned version of {base_model_name}\n")
                f.write(f"Created: {created_at}\n")
                f.write(f"Adapter path: {adapter_dir}\n")
                f.write(f"Output directory: {output_dir}\n")

            logger.info(f"Fine-tuned model successfully saved to {output_dir}")
            return True

        except Exception as e:
            logger.error(f"Error saving model: {e}")
            import traceback
            traceback.print_exc()
            return False

    def _cleanup_checkpoints(self, adapter_dir: Path) -> None:
        """
        Delete intermediate training checkpoints after a successful run.

        Checkpoints (e.g. ``0000100_adapters.safetensors``) protect against
        crashes during training but are not needed once the model is saved.
        The final adapter file does not match the numbered pattern, so it is
        kept. Failures are reported but never raised.

        Args:
            adapter_dir: Directory containing adapter files and checkpoints.
        """
        suffix = "_adapters.safetensors"
        try:
            if not adapter_dir.exists():
                return

            # A checkpoint is an all-digit prefix followed by the suffix.
            checkpoint_files = [
                path for path in adapter_dir.glob("*" + suffix)
                if path.name[: -len(suffix)].isdigit()
            ]
            if not checkpoint_files:
                logger.debug("No checkpoint files to clean up")
                return

            total_size = sum(path.stat().st_size for path in checkpoint_files)
            size_gb = total_size / (1024 ** 3)
            size_str = f"{size_gb:.2f}GB" if size_gb >= 1 else f"{total_size / (1024 ** 2):.1f}MB"

            logger.info(f"Cleaning up {len(checkpoint_files)} training checkpoints ({size_str})")

            for checkpoint in checkpoint_files:
                try:
                    checkpoint.unlink()
                    logger.debug(f"Deleted checkpoint: {checkpoint.name}")
                except Exception as e:
                    logger.warning(f"Failed to delete checkpoint {checkpoint.name}: {e}")

            logger.info(f"✓ Freed {size_str} by removing training checkpoints")
            print(f"\033[92m✓\033[0m Cleaned up {len(checkpoint_files)} training checkpoints ({size_str})")

        except Exception as e:
            # Cleanup is best-effort: training already succeeded.
            logger.warning(f"Checkpoint cleanup failed (non-critical): {e}")
            print(f"\033[93m⚠\033[0m Training succeeded but checkpoint cleanup encountered issues: {e}")
|