langtune 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langtune/__init__.py +315 -0
- langtune/acceleration.py +132 -0
- langtune/api.py +320 -0
- langtune/auth.py +434 -0
- langtune/callbacks.py +268 -0
- langtune/cli.py +687 -0
- langtune/client.py +721 -0
- langtune/config.py +356 -0
- langtune/data.py +526 -0
- langtune/distributed.py +154 -0
- langtune/facade.py +174 -0
- langtune/finetune.py +491 -0
- langtune/generation.py +95 -0
- langtune/logging_utils.py +182 -0
- langtune/metrics.py +345 -0
- langtune/model/__init__.py +20 -0
- langtune/model/hub.py +109 -0
- langtune/model/loader.py +84 -0
- langtune/model/safetensors.py +104 -0
- langtune/model/weights.py +100 -0
- langtune/models.py +19 -0
- langtune/nn/fast_transformer.py +399 -0
- langtune/nn/layers.py +178 -0
- langtune/nn/transformer.py +254 -0
- langtune/optimizations.py +870 -0
- langtune/py.typed +2 -0
- langtune/schedulers.py +234 -0
- langtune/tokenizers.py +275 -0
- langtune/trainer.py +889 -0
- langtune/training/neftune.py +80 -0
- langtune/utils.py +337 -0
- langtune-0.1.19.dist-info/METADATA +257 -0
- langtune-0.1.19.dist-info/RECORD +37 -0
- langtune-0.1.19.dist-info/WHEEL +5 -0
- langtune-0.1.19.dist-info/entry_points.txt +2 -0
- langtune-0.1.19.dist-info/licenses/LICENSE +21 -0
- langtune-0.1.19.dist-info/top_level.txt +1 -0
langtune/cli.py
ADDED
|
@@ -0,0 +1,687 @@
|
|
|
1
|
+
"""
|
|
2
|
+
cli.py: Command-line interface for Langtune
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import os
|
|
7
|
+
import sys
|
|
8
|
+
import logging
|
|
9
|
+
import torch
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Optional
|
|
12
|
+
|
|
13
|
+
from .config import Config, load_config, save_config, get_preset_config, validate_config
|
|
14
|
+
from .trainer import create_trainer
|
|
15
|
+
from .data import load_dataset_from_config, create_data_loader, DataCollator
|
|
16
|
+
from .models import LoRALanguageModel
|
|
17
|
+
from .auth import (
|
|
18
|
+
get_api_key, verify_api_key, check_usage, interactive_login, logout,
|
|
19
|
+
print_usage_info, AuthenticationError, UsageLimitError, require_auth
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# Try to import rich for beautiful output
|
|
23
|
+
try:
|
|
24
|
+
from rich.console import Console
|
|
25
|
+
from rich.panel import Panel
|
|
26
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
|
|
27
|
+
from rich.table import Table
|
|
28
|
+
from rich import box
|
|
29
|
+
from rich.text import Text
|
|
30
|
+
RICH_AVAILABLE = True
|
|
31
|
+
console = Console()
|
|
32
|
+
except ImportError:
|
|
33
|
+
RICH_AVAILABLE = False
|
|
34
|
+
console = None
|
|
35
|
+
|
|
36
|
+
# Version
# Prefer the version recorded in the installed distribution's metadata so the
# CLI cannot drift from the released package; fall back to the historical
# hard-coded value when that lookup is not possible (e.g. the package is not
# installed, or importlib.metadata is unavailable on this interpreter).
try:
    from importlib.metadata import version as _distribution_version

    __version__ = _distribution_version("langtune")
except Exception:
    __version__ = "0.1.2"

# Setup logging
# Configure the root logger once at import time: INFO level, timestamped
# records that carry the logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used by every command handler in this file.
logger = logging.getLogger(__name__)
|
|
45
|
+
|
|
46
|
+
def _check_auth():
    """Verify that an API key is configured and still has usage available.

    Prints guidance when no key is stored, and prints the error text when
    ``check_usage`` raises ``AuthenticationError`` or ``UsageLimitError``.
    With rich available, the remaining token count is printed on success.

    Returns:
        bool: True when the key passed the usage check, False otherwise.
    """
    key = get_api_key()

    # No stored key: tell the user how to obtain one and stop.
    if not key:
        if RICH_AVAILABLE:
            for line in (
                "\n[bold red]š Authentication Required[/]\n",
                "Langtune requires an API key to run. Get your free key at:",
                "[blue underline]https://app.langtrain.xyz[/]\n",
                "Then authenticate with: [cyan]langtune auth login[/]\n",
            ):
                console.print(line)
        else:
            for line in (
                "\nš Authentication Required\n",
                "Get your API key at: https://app.langtrain.xyz",
                "Then run: langtune auth login\n",
            ):
                print(line)
        return False

    try:
        usage = check_usage(key)
        if RICH_AVAILABLE:
            console.print(f"[dim]Tokens remaining: {usage['tokens_remaining']:,}[/]")
        return True
    except AuthenticationError as auth_error:
        detail = f"{auth_error}"
        rich_line, plain_line = f"[red]ā {detail}[/]", f"ā {detail}"
    except UsageLimitError as limit_error:
        detail = f"{limit_error}"
        rich_line, plain_line = f"[yellow]ā ļø {detail}[/]", f"ā ļø {detail}"

    # Only reached when one of the two handled errors was raised.
    if RICH_AVAILABLE:
        console.print(rich_line)
    else:
        print(plain_line)
    return False
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _split_size(dataset):
    """Return ``len(dataset)``, counting a missing (``None``) split as 0."""
    return 0 if dataset is None else len(dataset)


def train_command(args):
    """Handle the train command.

    Resolves a configuration from ``--config`` or ``--preset``, applies the
    command-line overrides, validates it, saves it as ``config.yaml`` in the
    output directory, builds the data loaders and runs the trainer.

    Args:
        args: Parsed argparse namespace of the ``train`` sub-command.

    Returns:
        int: 0 when training completes, 1 on any handled failure
        (authentication, missing/invalid configuration, dataset loading,
        trainer creation or the training run itself).
    """
    # Check authentication first
    if not _check_auth():
        return 1

    logger.info("Starting training...")

    # Load configuration
    if args.config:
        config = load_config(args.config)
    elif args.preset:
        config = get_preset_config(args.preset)
    else:
        logger.error("Either --config or --preset must be specified")
        return 1

    # Override config with command line arguments (only values that were given)
    if args.train_file:
        config.data.train_file = args.train_file
    if args.eval_file:
        config.data.eval_file = args.eval_file
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.batch_size:
        config.training.batch_size = args.batch_size
    if args.learning_rate:
        config.training.learning_rate = args.learning_rate
    if args.epochs:
        config.training.num_epochs = args.epochs
    # NOTE(review): the optimization flags registered for this sub-command
    # (--fast, --4bit, --gradient-checkpointing, --mixed-precision,
    # --gradient-accumulation) are not read anywhere in this function.

    # Validate configuration
    try:
        validate_config(config)
    except ValueError as e:
        logger.error(f"Configuration validation failed: {e}")
        return 1

    # Create output directory and persist the effective configuration in it
    os.makedirs(config.output_dir, exist_ok=True)
    config_path = os.path.join(config.output_dir, "config.yaml")
    save_config(config, config_path)
    logger.info(f"Configuration saved to {config_path}")

    # Load datasets
    try:
        train_dataset, val_dataset, test_dataset = load_dataset_from_config(config)
        # The loaders below already tolerate falsy val/test splits, so the
        # summary must not call len() on a missing split: that TypeError used
        # to be reported as a dataset loading failure.
        logger.info(
            f"Loaded datasets: {len(train_dataset)} train, "
            f"{_split_size(val_dataset)} val, {_split_size(test_dataset)} test"
        )
    except Exception as e:
        logger.error(f"Failed to load datasets: {e}")
        return 1

    # Create data loaders; every split shares the same settings except shuffling
    collate_fn = DataCollator()

    def build_loader(dataset, shuffle):
        return create_data_loader(
            dataset,
            batch_size=config.training.batch_size,
            shuffle=shuffle,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            collate_fn=collate_fn
        )

    train_dataloader = build_loader(train_dataset, True)
    val_dataloader = build_loader(val_dataset, False) if val_dataset else None
    test_dataloader = build_loader(test_dataset, False) if test_dataset else None

    # Create trainer
    try:
        trainer = create_trainer(
            config=config,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            test_dataloader=test_dataloader
        )
    except Exception as e:
        logger.error(f"Failed to create trainer: {e}")
        return 1

    # Start training
    try:
        trainer.train(resume_from_checkpoint=args.resume_from)
        logger.info("Training completed successfully!")
        return 0
    except Exception as e:
        logger.error(f"Training failed: {e}")
        return 1
|
|
186
|
+
|
|
187
|
+
def evaluate_command(args):
    """Handle the evaluate command.

    Rebuilds the model from the configuration, loads the checkpoint weights,
    and logs the mean loss over the test split.

    Args:
        args: Parsed argparse namespace of the ``evaluate`` sub-command
            (reads ``config`` and ``model_path``).

    Returns:
        int: 0 when the test loss was computed, 1 on any handled failure.
    """
    # Check authentication first
    if not _check_auth():
        return 1

    logger.info("Starting evaluation...")

    if not args.model_path:
        logger.error("--model_path is required for evaluation")
        return 1

    # Load configuration
    if args.config:
        config = load_config(args.config)
    else:
        logger.error("--config is required for evaluation")
        return 1

    # Load model
    try:
        model = LoRALanguageModel(
            vocab_size=config.model.vocab_size,
            embed_dim=config.model.embed_dim,
            num_layers=config.model.num_layers,
            num_heads=config.model.num_heads,
            max_seq_len=config.model.max_seq_len,
            mlp_ratio=config.model.mlp_ratio,
            dropout=config.model.dropout,
            lora_config=config.model.lora.__dict__ if config.model.lora else None
        )

        # NOTE(security): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be passed as --model-path.
        checkpoint = torch.load(args.model_path, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()

        logger.info(f"Model loaded from {args.model_path}")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return 1

    # Load test dataset
    try:
        _, _, test_dataset = load_dataset_from_config(config)
        if test_dataset is None:
            # Fail with a clear message instead of an opaque error later on.
            logger.error("Failed to load test dataset: no test split is available")
            return 1
        test_dataloader = create_data_loader(
            test_dataset,
            batch_size=config.training.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            collate_fn=DataCollator()
        )
    except Exception as e:
        logger.error(f"Failed to load test dataset: {e}")
        return 1

    # Evaluate
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in test_dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                total_loss += outputs["loss"].item()
                num_batches += 1

        if num_batches == 0:
            # An empty loader would otherwise surface as "division by zero".
            logger.error("Evaluation failed: the test dataset produced no batches")
            return 1

        avg_loss = total_loss / num_batches
        logger.info(f"Test loss: {avg_loss:.4f}")

        return 0
    except Exception as e:
        logger.error(f"Evaluation failed: {e}")
        return 1
|
|
265
|
+
|
|
266
|
+
def generate_command(args):
    """Handle the generate command.

    Rebuilds the model from the configuration, loads the checkpoint weights,
    encodes the prompt one character per token (id = Unicode code point),
    samples a continuation and prints the prompt and the decoded result.

    Args:
        args: Parsed argparse namespace of the ``generate`` sub-command
            (reads ``config``, ``model_path``, ``prompt``, ``max_length``,
            ``temperature``, ``top_k`` and ``top_p``).

    Returns:
        int: 0 when text was generated, 1 on any handled failure.
    """
    # Check authentication first
    if not _check_auth():
        return 1

    logger.info("Starting text generation...")

    if not args.model_path:
        logger.error("--model_path is required for generation")
        return 1

    # Load configuration
    if args.config:
        config = load_config(args.config)
    else:
        logger.error("--config is required for generation")
        return 1

    # Load model
    try:
        model = LoRALanguageModel(
            vocab_size=config.model.vocab_size,
            embed_dim=config.model.embed_dim,
            num_layers=config.model.num_layers,
            num_heads=config.model.num_heads,
            max_seq_len=config.model.max_seq_len,
            mlp_ratio=config.model.mlp_ratio,
            dropout=config.model.dropout,
            lora_config=config.model.lora.__dict__ if config.model.lora else None
        )

        # NOTE(security): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be passed as --model-path.
        checkpoint = torch.load(args.model_path, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()

        logger.info(f"Model loaded from {args.model_path}")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return 1

    # Generate text
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

        prompt = args.prompt or "The quick brown fox"
        max_length = args.max_length or 100

        # Simple tokenization: only the first 50 characters are fed to the model
        prompt_window = prompt[:50]
        if len(prompt) > len(prompt_window):
            # The full prompt is echoed below, so make the truncation visible.
            logger.warning(
                f"Prompt has {len(prompt)} characters; only the first "
                f"{len(prompt_window)} are used"
            )
        token_ids = [ord(c) for c in prompt_window]

        # Presumably ids must be smaller than the model's vocab_size (TODO
        # confirm against LoRALanguageModel); warn rather than abort so the
        # existing behavior is unchanged when that assumption does not hold.
        vocab_size = config.model.vocab_size
        if isinstance(vocab_size, int) and any(t >= vocab_size for t in token_ids):
            logger.warning(
                f"Prompt contains characters whose code point is >= vocab_size "
                f"({vocab_size}); generation may fail"
            )

        input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)

        with torch.no_grad():
            generated = model.generate(
                input_ids,
                max_length=max_length,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p
            )

        # Simple decoding: map every generated id back to its character
        generated_text = "".join([chr(i) for i in generated[0].cpu().tolist()])

        print(f"Prompt: {prompt}")
        print(f"Generated: {generated_text}")

        return 0
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        return 1
|
|
337
|
+
|
|
338
|
+
def concept_command(args):
    """Handle the concept command.

    Runs a simulated demonstration: it announces the upper-cased concept
    name, advances a 100-step progress display with a 0.02 s pause per step,
    and reports completion. No model code is executed.

    Args:
        args: Parsed argparse namespace of the ``concept`` sub-command
            (reads ``concept``).

    Returns:
        int: Always 0.
    """
    import time

    concept_name = args.concept.upper()

    if RICH_AVAILABLE:
        console.print(f"\n[bold cyan]š§Ŗ Running concept demonstration:[/] [bold magenta]{concept_name}[/]\n")

        # Simulate concept execution with rich progress
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console
        ) as progress:
            task = progress.add_task(f"[cyan]Processing {concept_name}...", total=100)
            for _ in range(100):
                time.sleep(0.02)
                progress.update(task, advance=1)

        console.print(f"\n[bold green]ā[/] {concept_name} demonstration completed!\n")
    else:
        logger.info(f"Running concept demonstration: {concept_name}")

        # tqdm is optional as well: without it the steps run with no progress
        # bar instead of crashing with ImportError.
        steps = range(100)
        try:
            from tqdm import tqdm
        except ImportError:
            pass
        else:
            steps = tqdm(steps, desc=f"Progress for {concept_name}")

        for _ in steps:
            time.sleep(0.02)
        logger.info(f"{concept_name} demonstration completed!")

    return 0
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def _check_tpu() -> bool:
    """Check if Google TPU is available via torch_xla.

    Returns:
        bool: True when torch_xla can be imported, yields an XLA device, and
        that device is backed by TPU hardware; False otherwise (including
        when torch_xla is not installed).
    """
    try:
        import torch_xla.core.xla_model as xm

        device = xm.xla_device()
    except Exception:
        # torch_xla missing or no XLA device could be acquired.
        return False

    try:
        # An XLA device can also be backed by CPU or GPU, so ask for the
        # hardware type instead of trusting the generic "xla:N" device name.
        return xm.xla_device_hw(device) == "TPU"
    except Exception:
        # Hardware query unavailable: fall back to the name-based heuristic.
        device_name = str(device)
        return "TPU" in device_name or "xla" in device_name.lower()
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def _get_tpu_info() -> str:
    """Get TPU information if available.

    Returns:
        str: ``"<version> (<n> cores)"`` where the version is guessed from
        the ``TPU_NAME`` environment variable (empty when it matches no known
        generation), or ``"(available)"`` when the core count cannot be
        queried through torch_xla.
    """
    try:
        import torch_xla.core.xla_model as xm

        try:
            # Preferred API on current torch_xla releases.
            import torch_xla.runtime as xr

            tpu_cores = xr.world_size()
        except (ImportError, AttributeError):
            # Older torch_xla releases only expose the XRT-era helper.
            tpu_cores = xm.xrt_world_size()
    except Exception:
        return "(available)"

    # Detect version from the (user-provided) TPU name; newest first.
    tpu_name = os.environ.get("TPU_NAME", "").lower()
    version = next(
        (tag for tag in ("v6", "v5", "v4", "v3", "v2") if tag in tpu_name),
        "",
    )

    return f"{version} ({tpu_cores} cores)"
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def version_command(args):
    """Handle the version command.

    Prints the Langtune, Python and PyTorch versions together with the
    detected accelerator. With rich available this is a table whose
    accelerator is the first match of CUDA, TPU, Apple MPS, or CPU; without
    rich it is plain lines reporting CUDA and TPU availability.

    Args:
        args: Parsed argparse namespace (not used).

    Returns:
        int: Always 0.
    """
    python_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"

    if RICH_AVAILABLE:
        # Check for NVIDIA CUDA
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            gpu_count = torch.cuda.device_count()

            try:
                gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                cuda_version = torch.version.cuda

                if gpu_count > 1:
                    gpu_info = f"[green]ā NVIDIA {gpu_name} Ć {gpu_count} ({gpu_memory:.0f}GB each)[/]"
                else:
                    gpu_info = f"[green]ā NVIDIA {gpu_name} ({gpu_memory:.0f}GB)[/]"

                accelerator_type = f"CUDA {cuda_version}"
            except Exception:
                # Detailed properties unavailable: still report the device,
                # but let KeyboardInterrupt/SystemExit propagate.
                gpu_info = f"[green]ā NVIDIA {gpu_name}[/]"
                accelerator_type = "CUDA"
        # Check for Google TPU
        elif _check_tpu():
            tpu_info = _get_tpu_info()
            gpu_info = f"[green]ā Google TPU {tpu_info}[/]"
            accelerator_type = "TPU (torch_xla)"
        # Check for Apple MPS
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            gpu_info = "[green]ā Apple Metal Performance Shaders (MPS)[/]"
            accelerator_type = "Metal"
        else:
            gpu_info = "[yellow]ā Not available (CPU mode)[/]"
            accelerator_type = "CPU"

        table = Table(title="Langtune System Info", box=box.ROUNDED, title_style="bold magenta")
        table.add_column("Component", style="cyan", no_wrap=True)
        table.add_column("Value", style="white")

        table.add_row("Langtune Version", f"v{__version__}")
        table.add_row("Python Version", python_version)
        table.add_row("PyTorch Version", torch.__version__)
        table.add_row("Accelerator", gpu_info)
        table.add_row("Backend", accelerator_type)

        console.print()
        console.print(table)
        console.print()
    else:
        print(f"Langtune v{__version__}")
        print(f"Python {python_version}")
        print(f"PyTorch {torch.__version__}")
        print(f"CUDA: {'Available' if torch.cuda.is_available() else 'Not available'}")
        print(f"TPU: {'Available' if _check_tpu() else 'Not available'}")

    return 0
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def info_command(args):
    """Handle the info command - show quick start guide.

    With rich available, prints a panel with the four quick-start steps, a
    table of the model presets and the documentation/issue links; otherwise
    prints the same guide as one plain-text block.

    Args:
        args: Parsed argparse namespace (not used).

    Returns:
        int: Always 0.
    """
    if not RICH_AVAILABLE:
        plain_guide = """
š Quick Start Guide
====================

1. Prepare your data
Place your training text in a .txt or .json file

2. Start training
langtune train --preset small --train-file data.txt

3. Evaluate your model
langtune evaluate --config config.yaml --model-path model.pt

4. Generate text
langtune generate --config config.yaml --model-path model.pt --prompt "Hello"

Available Presets: tiny, small, base, large

š Docs: https://github.com/langtrain-ai/langtune
"""
        print(plain_guide)
        return 0

    console.print()

    # Quick start panel: (text, style) pairs appended in display order
    guide_segments = (
        ("1. Prepare your data\n", "bold cyan"),
        (" Place your training text in a .txt or .json file\n\n", None),
        ("2. Start training\n", "bold cyan"),
        (" langtune train --preset small --train-file data.txt\n\n", "green"),
        ("3. Evaluate your model\n", "bold cyan"),
        (" langtune evaluate --config config.yaml --model-path model.pt\n\n", "green"),
        ("4. Generate text\n", "bold cyan"),
        (" langtune generate --config config.yaml --model-path model.pt --prompt \"Hello\"\n", "green"),
    )
    guide = Text()
    for segment, segment_style in guide_segments:
        guide.append(segment, style=segment_style)

    console.print(
        Panel(
            guide,
            title="[bold]š Quick Start Guide[/]",
            border_style="cyan",
            box=box.ROUNDED
        )
    )

    # Available presets
    presets = Table(title="Available Model Presets", box=box.SIMPLE)
    for heading, column_style in (
        ("Preset", "cyan bold"),
        ("Parameters", "white"),
        ("Use Case", "dim"),
    ):
        presets.add_column(heading, style=column_style)

    for preset_row in (
        ("tiny", "~1M", "Quick experiments, testing"),
        ("small", "~10M", "Small datasets, fast training"),
        ("base", "~50M", "General purpose"),
        ("large", "~100M+", "Large datasets, best quality"),
    ):
        presets.add_row(*preset_row)

    console.print(presets)
    console.print()

    # Links
    for link_line in (
        "[dim]š Documentation:[/] [blue underline]https://github.com/langtrain-ai/langtune[/]",
        "[dim]š Report issues:[/] [blue underline]https://github.com/langtrain-ai/langtune/issues[/]",
    ):
        console.print(link_line)
    console.print()

    return 0
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
def _print_banner():
    """Print the CLI banner.

    Builds the styled banner as a rich ``Text`` and prints it on the shared
    console. Nothing is printed when rich is not available.
    """
    if not RICH_AVAILABLE:
        return

    frame_style = "cyan bold"
    # (text, style) pairs in display order: top rule, title row, tagline row,
    # bottom rule.
    pieces = (
        ("\nāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā\n", frame_style),
        ("ā", frame_style),
        (" ", ""),
        ("LANGTUNE", "bold magenta"),
        (" ", ""),
        ("ā\n", frame_style),
        ("ā", frame_style),
        (" Efficient LoRA Fine-Tuning for LLMs ", "dim"),
        ("ā\n", frame_style),
        ("āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā\n", frame_style),
    )

    banner = Text()
    for piece, piece_style in pieces:
        banner.append(piece, style=piece_style)
    console.print(banner)
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def _build_parser():
    """Create the top-level argument parser with every sub-command registered."""
    parser = argparse.ArgumentParser(
        description='Langtune: Efficient LoRA Fine-Tuning for Text LLMs',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
langtune info # Show quick start guide
langtune version # Show version info
langtune train --preset small --train-file data.txt # Train with preset
langtune train --config config.yaml # Train with config
langtune evaluate --config config.yaml --model-path m.pt # Evaluate model
langtune generate --config c.yaml --model-path m.pt # Generate text
langtune concept --concept rlhf # Concept demo

Learn more: https://github.com/langtrain-ai/langtune
"""
    )

    parser.add_argument('-v', '--version', action='store_true', help='Show version information')

    subparsers = parser.add_subparsers(dest='command', help='Available commands')

    # Auth command
    auth_parser = subparsers.add_parser('auth', help='Manage API key authentication')
    auth_subparsers = auth_parser.add_subparsers(dest='auth_command', help='Auth commands')
    auth_subparsers.add_parser('login', help='Login with your API key')
    auth_subparsers.add_parser('logout', help='Remove stored API key')
    auth_subparsers.add_parser('status', help='Show authentication status and usage')

    # Version command
    subparsers.add_parser('version', help='Show version and system information')

    # Info command
    subparsers.add_parser('info', help='Show quick start guide and documentation')

    # Train command
    train_parser = subparsers.add_parser('train', help='Train a model with LoRA')
    train_parser.add_argument('--config', type=str, help='Path to configuration file')
    train_parser.add_argument('--preset', type=str, choices=['tiny', 'small', 'base', 'large'],
                              help='Use a preset configuration')
    train_parser.add_argument('--train-file', type=str, help='Path to training data file')
    train_parser.add_argument('--eval-file', type=str, help='Path to evaluation data file')
    train_parser.add_argument('--output-dir', type=str, help='Output directory for checkpoints')
    train_parser.add_argument('--batch-size', type=int, help='Batch size')
    train_parser.add_argument('--learning-rate', type=float, help='Learning rate')
    train_parser.add_argument('--epochs', type=int, help='Number of epochs')
    train_parser.add_argument('--resume-from', type=str, help='Resume from checkpoint')

    # Optimization flags
    train_parser.add_argument('--fast', action='store_true',
                              help='Use FastLoRALanguageModel with all optimizations (RoPE, flash attention, grad checkpointing)')
    train_parser.add_argument('--4bit', dest='use_4bit', action='store_true',
                              help='Use 4-bit quantization (QLoRA style)')
    train_parser.add_argument('--gradient-checkpointing', action='store_true',
                              help='Enable gradient checkpointing to reduce memory')
    train_parser.add_argument('--mixed-precision', type=str, choices=['fp16', 'bf16', 'fp32'],
                              default='fp16', help='Mixed precision training mode')
    train_parser.add_argument('--gradient-accumulation', type=int, default=1,
                              help='Number of gradient accumulation steps')

    # Evaluate command
    eval_parser = subparsers.add_parser('evaluate', help='Evaluate a trained model')
    eval_parser.add_argument('--config', type=str, required=True, help='Path to configuration file')
    eval_parser.add_argument('--model-path', type=str, required=True, help='Path to model checkpoint')

    # Generate command
    gen_parser = subparsers.add_parser('generate', help='Generate text with a trained model')
    gen_parser.add_argument('--config', type=str, required=True, help='Path to configuration file')
    gen_parser.add_argument('--model-path', type=str, required=True, help='Path to model checkpoint')
    gen_parser.add_argument('--prompt', type=str, help='Text prompt for generation')
    gen_parser.add_argument('--max-length', type=int, help='Maximum generation length')
    gen_parser.add_argument('--temperature', type=float, default=1.0, help='Sampling temperature')
    gen_parser.add_argument('--top-k', type=int, help='Top-k sampling')
    gen_parser.add_argument('--top-p', type=float, help='Top-p (nucleus) sampling')

    # Concept command
    concept_parser = subparsers.add_parser('concept', help='Run a concept demonstration')
    concept_parser.add_argument('--concept', type=str, required=True,
                                choices=['rlhf', 'cot', 'ccot', 'grpo', 'rlvr', 'dpo', 'ppo', 'lime', 'shap'],
                                help='LLM concept to demonstrate')

    return parser


def _run_auth(args):
    """Run the ``auth`` sub-command selected in *args* and return its exit code."""
    if not args.auth_command:
        # Show auth help
        if RICH_AVAILABLE:
            console.print("\n[bold cyan]š Authentication Commands[/]\n")
            console.print(" [cyan]langtune auth login[/] - Login with your API key")
            console.print(" [cyan]langtune auth logout[/] - Remove stored API key")
            console.print(" [cyan]langtune auth status[/] - Show auth status and usage\n")
            console.print("[dim]Get your API key at:[/] [blue underline]https://app.langtrain.xyz[/]\n")
        else:
            print("\nAuthentication Commands:\n")
            print(" langtune auth login - Login with your API key")
            print(" langtune auth logout - Remove stored API key")
            print(" langtune auth status - Show auth status and usage\n")
        return 0

    if args.auth_command == 'login':
        return 0 if interactive_login() else 1
    if args.auth_command == 'logout':
        logout()
        return 0
    if args.auth_command == 'status':
        print_usage_info()
        return 0

    # argparse rejects unknown auth sub-commands, so this is defensive only;
    # it replaces the previous implicit ``None`` result with a failure code.
    return 1


def main(argv=None):
    """Main CLI entry point.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When omitted, argparse reads ``sys.argv[1:]``, which is
            what the console entry point relies on.

    Returns:
        int: Process exit code; 0 on success, 1 on failure or when no
        command was given.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Handle -v/--version flag
    if args.version:
        return version_command(args)

    if not args.command:
        _print_banner()
        parser.print_help()
        if RICH_AVAILABLE:
            console.print("\n[dim]š” Tip: Run[/] [cyan]langtune info[/] [dim]for a quick start guide[/]\n")
        return 1

    # Route to appropriate command handler
    handlers = {
        'auth': _run_auth,
        'version': version_command,
        'info': info_command,
        'train': train_command,
        'evaluate': evaluate_command,
        'generate': generate_command,
        'concept': concept_command,
    }
    handler = handlers.get(args.command)
    if handler is None:
        parser.print_help()
        return 1
    return handler(args)
|
|
685
|
+
|
|
686
|
+
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|