langtune 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langtune/cli.py ADDED
@@ -0,0 +1,687 @@
1
+ """
2
+ cli.py: Command-line interface for Langtune
3
+ """
4
+
5
+ import argparse
6
+ import os
7
+ import sys
8
+ import logging
9
+ import torch
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ from .config import Config, load_config, save_config, get_preset_config, validate_config
14
+ from .trainer import create_trainer
15
+ from .data import load_dataset_from_config, create_data_loader, DataCollator
16
+ from .models import LoRALanguageModel
17
+ from .auth import (
18
+ get_api_key, verify_api_key, check_usage, interactive_login, logout,
19
+ print_usage_info, AuthenticationError, UsageLimitError, require_auth
20
+ )
21
+
22
+ # Try to import rich for beautiful output
23
+ try:
24
+ from rich.console import Console
25
+ from rich.panel import Panel
26
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
27
+ from rich.table import Table
28
+ from rich import box
29
+ from rich.text import Text
30
+ RICH_AVAILABLE = True
31
+ console = Console()
32
+ except ImportError:
33
+ RICH_AVAILABLE = False
34
+ console = None
35
+
36
+ # Version
37
+ __version__ = "0.1.2"
38
+
39
+ # Setup logging
40
+ logging.basicConfig(
41
+ level=logging.INFO,
42
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
43
+ )
44
+ logger = logging.getLogger(__name__)
45
+
46
+ def _check_auth():
47
+ """Check authentication before running protected commands."""
48
+ api_key = get_api_key()
49
+
50
+ if not api_key:
51
+ if RICH_AVAILABLE:
52
+ console.print("\n[bold red]šŸ” Authentication Required[/]\n")
53
+ console.print("Langtune requires an API key to run. Get your free key at:")
54
+ console.print("[blue underline]https://app.langtrain.xyz[/]\n")
55
+ console.print("Then authenticate with: [cyan]langtune auth login[/]\n")
56
+ else:
57
+ print("\nšŸ” Authentication Required\n")
58
+ print("Get your API key at: https://app.langtrain.xyz")
59
+ print("Then run: langtune auth login\n")
60
+ return False
61
+
62
+ try:
63
+ usage = check_usage(api_key)
64
+ if RICH_AVAILABLE:
65
+ remaining = f"{usage['tokens_remaining']:,}"
66
+ console.print(f"[dim]Tokens remaining: {remaining}[/]")
67
+ return True
68
+ except AuthenticationError as e:
69
+ if RICH_AVAILABLE:
70
+ console.print(f"[red]āŒ {e}[/]")
71
+ else:
72
+ print(f"āŒ {e}")
73
+ return False
74
+ except UsageLimitError as e:
75
+ if RICH_AVAILABLE:
76
+ console.print(f"[yellow]āš ļø {e}[/]")
77
+ else:
78
+ print(f"āš ļø {e}")
79
+ return False
80
+
81
+
82
+ def train_command(args):
83
+ """Handle the train command."""
84
+ # Check authentication first
85
+ if not _check_auth():
86
+ return 1
87
+
88
+ logger.info("Starting training...")
89
+
90
+ # Load configuration
91
+ if args.config:
92
+ config = load_config(args.config)
93
+ elif args.preset:
94
+ config = get_preset_config(args.preset)
95
+ else:
96
+ logger.error("Either --config or --preset must be specified")
97
+ return 1
98
+
99
+ # Override config with command line arguments
100
+ if args.train_file:
101
+ config.data.train_file = args.train_file
102
+ if args.eval_file:
103
+ config.data.eval_file = args.eval_file
104
+ if args.output_dir:
105
+ config.output_dir = args.output_dir
106
+ if args.batch_size:
107
+ config.training.batch_size = args.batch_size
108
+ if args.learning_rate:
109
+ config.training.learning_rate = args.learning_rate
110
+ if args.epochs:
111
+ config.training.num_epochs = args.epochs
112
+
113
+ # Validate configuration
114
+ try:
115
+ validate_config(config)
116
+ except ValueError as e:
117
+ logger.error(f"Configuration validation failed: {e}")
118
+ return 1
119
+
120
+ # Create output directory
121
+ os.makedirs(config.output_dir, exist_ok=True)
122
+
123
+ # Save configuration
124
+ config_path = os.path.join(config.output_dir, "config.yaml")
125
+ save_config(config, config_path)
126
+ logger.info(f"Configuration saved to {config_path}")
127
+
128
+ # Load datasets
129
+ try:
130
+ train_dataset, val_dataset, test_dataset = load_dataset_from_config(config)
131
+ logger.info(f"Loaded datasets: {len(train_dataset)} train, {len(val_dataset)} val, {len(test_dataset)} test")
132
+ except Exception as e:
133
+ logger.error(f"Failed to load datasets: {e}")
134
+ return 1
135
+
136
+ # Create data loaders
137
+ collate_fn = DataCollator()
138
+
139
+ train_dataloader = create_data_loader(
140
+ train_dataset,
141
+ batch_size=config.training.batch_size,
142
+ shuffle=True,
143
+ num_workers=config.num_workers,
144
+ pin_memory=config.pin_memory,
145
+ collate_fn=collate_fn
146
+ )
147
+
148
+ val_dataloader = create_data_loader(
149
+ val_dataset,
150
+ batch_size=config.training.batch_size,
151
+ shuffle=False,
152
+ num_workers=config.num_workers,
153
+ pin_memory=config.pin_memory,
154
+ collate_fn=collate_fn
155
+ ) if val_dataset else None
156
+
157
+ test_dataloader = create_data_loader(
158
+ test_dataset,
159
+ batch_size=config.training.batch_size,
160
+ shuffle=False,
161
+ num_workers=config.num_workers,
162
+ pin_memory=config.pin_memory,
163
+ collate_fn=collate_fn
164
+ ) if test_dataset else None
165
+
166
+ # Create trainer
167
+ try:
168
+ trainer = create_trainer(
169
+ config=config,
170
+ train_dataloader=train_dataloader,
171
+ val_dataloader=val_dataloader,
172
+ test_dataloader=test_dataloader
173
+ )
174
+ except Exception as e:
175
+ logger.error(f"Failed to create trainer: {e}")
176
+ return 1
177
+
178
+ # Start training
179
+ try:
180
+ trainer.train(resume_from_checkpoint=args.resume_from)
181
+ logger.info("Training completed successfully!")
182
+ return 0
183
+ except Exception as e:
184
+ logger.error(f"Training failed: {e}")
185
+ return 1
186
+
187
+ def evaluate_command(args):
188
+ """Handle the evaluate command."""
189
+ # Check authentication first
190
+ if not _check_auth():
191
+ return 1
192
+
193
+ logger.info("Starting evaluation...")
194
+
195
+ if not args.model_path:
196
+ logger.error("--model_path is required for evaluation")
197
+ return 1
198
+
199
+ # Load configuration
200
+ if args.config:
201
+ config = load_config(args.config)
202
+ else:
203
+ logger.error("--config is required for evaluation")
204
+ return 1
205
+
206
+ # Load model
207
+ try:
208
+ model = LoRALanguageModel(
209
+ vocab_size=config.model.vocab_size,
210
+ embed_dim=config.model.embed_dim,
211
+ num_layers=config.model.num_layers,
212
+ num_heads=config.model.num_heads,
213
+ max_seq_len=config.model.max_seq_len,
214
+ mlp_ratio=config.model.mlp_ratio,
215
+ dropout=config.model.dropout,
216
+ lora_config=config.model.lora.__dict__ if config.model.lora else None
217
+ )
218
+
219
+ checkpoint = torch.load(args.model_path, map_location="cpu")
220
+ model.load_state_dict(checkpoint["model_state_dict"])
221
+ model.eval()
222
+
223
+ logger.info(f"Model loaded from {args.model_path}")
224
+ except Exception as e:
225
+ logger.error(f"Failed to load model: {e}")
226
+ return 1
227
+
228
+ # Load test dataset
229
+ try:
230
+ _, _, test_dataset = load_dataset_from_config(config)
231
+ test_dataloader = create_data_loader(
232
+ test_dataset,
233
+ batch_size=config.training.batch_size,
234
+ shuffle=False,
235
+ num_workers=config.num_workers,
236
+ pin_memory=config.pin_memory,
237
+ collate_fn=DataCollator()
238
+ )
239
+ except Exception as e:
240
+ logger.error(f"Failed to load test dataset: {e}")
241
+ return 1
242
+
243
+ # Evaluate
244
+ try:
245
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
246
+ model.to(device)
247
+
248
+ total_loss = 0.0
249
+ num_batches = 0
250
+
251
+ with torch.no_grad():
252
+ for batch in test_dataloader:
253
+ batch = {k: v.to(device) for k, v in batch.items()}
254
+ outputs = model(**batch)
255
+ total_loss += outputs["loss"].item()
256
+ num_batches += 1
257
+
258
+ avg_loss = total_loss / num_batches
259
+ logger.info(f"Test loss: {avg_loss:.4f}")
260
+
261
+ return 0
262
+ except Exception as e:
263
+ logger.error(f"Evaluation failed: {e}")
264
+ return 1
265
+
266
+ def generate_command(args):
267
+ """Handle the generate command."""
268
+ # Check authentication first
269
+ if not _check_auth():
270
+ return 1
271
+
272
+ logger.info("Starting text generation...")
273
+
274
+ if not args.model_path:
275
+ logger.error("--model_path is required for generation")
276
+ return 1
277
+
278
+ # Load configuration
279
+ if args.config:
280
+ config = load_config(args.config)
281
+ else:
282
+ logger.error("--config is required for generation")
283
+ return 1
284
+
285
+ # Load model
286
+ try:
287
+ model = LoRALanguageModel(
288
+ vocab_size=config.model.vocab_size,
289
+ embed_dim=config.model.embed_dim,
290
+ num_layers=config.model.num_layers,
291
+ num_heads=config.model.num_heads,
292
+ max_seq_len=config.model.max_seq_len,
293
+ mlp_ratio=config.model.mlp_ratio,
294
+ dropout=config.model.dropout,
295
+ lora_config=config.model.lora.__dict__ if config.model.lora else None
296
+ )
297
+
298
+ checkpoint = torch.load(args.model_path, map_location="cpu")
299
+ model.load_state_dict(checkpoint["model_state_dict"])
300
+ model.eval()
301
+
302
+ logger.info(f"Model loaded from {args.model_path}")
303
+ except Exception as e:
304
+ logger.error(f"Failed to load model: {e}")
305
+ return 1
306
+
307
+ # Generate text
308
+ try:
309
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
310
+ model.to(device)
311
+
312
+ prompt = args.prompt or "The quick brown fox"
313
+ max_length = args.max_length or 100
314
+
315
+ # Simple tokenization
316
+ input_ids = torch.tensor([ord(c) for c in prompt[:50]], dtype=torch.long).unsqueeze(0).to(device)
317
+
318
+ with torch.no_grad():
319
+ generated = model.generate(
320
+ input_ids,
321
+ max_length=max_length,
322
+ temperature=args.temperature,
323
+ top_k=args.top_k,
324
+ top_p=args.top_p
325
+ )
326
+
327
+ # Simple decoding
328
+ generated_text = "".join([chr(i) for i in generated[0].cpu().tolist()])
329
+
330
+ print(f"Prompt: {prompt}")
331
+ print(f"Generated: {generated_text}")
332
+
333
+ return 0
334
+ except Exception as e:
335
+ logger.error(f"Generation failed: {e}")
336
+ return 1
337
+
338
+ def concept_command(args):
339
+ """Handle the concept command."""
340
+ concept_name = args.concept.upper()
341
+
342
+ if RICH_AVAILABLE:
343
+ console.print(f"\n[bold cyan]🧪 Running concept demonstration:[/] [bold magenta]{concept_name}[/]\n")
344
+ else:
345
+ logger.info(f"Running concept demonstration: {concept_name}")
346
+
347
+ # Simulate concept execution with rich progress
348
+ import time
349
+
350
+ if RICH_AVAILABLE:
351
+ with Progress(
352
+ SpinnerColumn(),
353
+ TextColumn("[progress.description]{task.description}"),
354
+ BarColumn(),
355
+ TaskProgressColumn(),
356
+ console=console
357
+ ) as progress:
358
+ task = progress.add_task(f"[cyan]Processing {concept_name}...", total=100)
359
+ for i in range(100):
360
+ time.sleep(0.02)
361
+ progress.update(task, advance=1)
362
+
363
+ console.print(f"\n[bold green]āœ“[/] {concept_name} demonstration completed!\n")
364
+ else:
365
+ from tqdm import tqdm
366
+ for i in tqdm(range(100), desc=f"Progress for {concept_name}"):
367
+ time.sleep(0.02)
368
+ logger.info(f"{concept_name} demonstration completed!")
369
+
370
+ return 0
371
+
372
+
373
+ def _check_tpu() -> bool:
374
+ """Check if Google TPU is available via torch_xla."""
375
+ try:
376
+ import torch_xla
377
+ import torch_xla.core.xla_model as xm
378
+ device = xm.xla_device()
379
+ return "TPU" in str(device) or "xla" in str(device).lower()
380
+ except:
381
+ return False
382
+
383
+
384
+ def _get_tpu_info() -> str:
385
+ """Get TPU information if available."""
386
+ try:
387
+ import os
388
+ import torch_xla.core.xla_model as xm
389
+ tpu_name = os.environ.get("TPU_NAME", "")
390
+ tpu_cores = xm.xrt_world_size()
391
+
392
+ # Detect version
393
+ if "v4" in tpu_name.lower():
394
+ version = "v4"
395
+ elif "v3" in tpu_name.lower():
396
+ version = "v3"
397
+ elif "v2" in tpu_name.lower():
398
+ version = "v2"
399
+ else:
400
+ version = ""
401
+
402
+ return f"{version} ({tpu_cores} cores)"
403
+ except:
404
+ return "(available)"
405
+
406
+
407
+ def version_command(args):
408
+ """Handle the version command."""
409
+ if RICH_AVAILABLE:
410
+ # Check accelerator availability with detailed info
411
+ accelerator_type = "None"
412
+
413
+ # Check for NVIDIA CUDA
414
+ if torch.cuda.is_available():
415
+ gpu_name = torch.cuda.get_device_name(0)
416
+ gpu_count = torch.cuda.device_count()
417
+
418
+ try:
419
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
420
+ cuda_version = torch.version.cuda
421
+
422
+ if gpu_count > 1:
423
+ gpu_info = f"[green]āœ“ NVIDIA {gpu_name} Ɨ {gpu_count} ({gpu_memory:.0f}GB each)[/]"
424
+ else:
425
+ gpu_info = f"[green]āœ“ NVIDIA {gpu_name} ({gpu_memory:.0f}GB)[/]"
426
+
427
+ accelerator_type = f"CUDA {cuda_version}"
428
+ except:
429
+ gpu_info = f"[green]āœ“ NVIDIA {gpu_name}[/]"
430
+ accelerator_type = "CUDA"
431
+ # Check for Google TPU
432
+ elif _check_tpu():
433
+ tpu_info = _get_tpu_info()
434
+ gpu_info = f"[green]āœ“ Google TPU {tpu_info}[/]"
435
+ accelerator_type = "TPU (torch_xla)"
436
+ # Check for Apple MPS
437
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
438
+ gpu_info = "[green]āœ“ Apple Metal Performance Shaders (MPS)[/]"
439
+ accelerator_type = "Metal"
440
+ else:
441
+ gpu_info = "[yellow]ā—‹ Not available (CPU mode)[/]"
442
+ accelerator_type = "CPU"
443
+
444
+ table = Table(title="Langtune System Info", box=box.ROUNDED, title_style="bold magenta")
445
+ table.add_column("Component", style="cyan", no_wrap=True)
446
+ table.add_column("Value", style="white")
447
+
448
+ table.add_row("Langtune Version", f"v{__version__}")
449
+ table.add_row("Python Version", f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}")
450
+ table.add_row("PyTorch Version", torch.__version__)
451
+ table.add_row("Accelerator", gpu_info)
452
+ table.add_row("Backend", accelerator_type)
453
+
454
+ console.print()
455
+ console.print(table)
456
+ console.print()
457
+ else:
458
+ print(f"Langtune v{__version__}")
459
+ print(f"Python {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}")
460
+ print(f"PyTorch {torch.__version__}")
461
+ print(f"CUDA: {'Available' if torch.cuda.is_available() else 'Not available'}")
462
+ print(f"TPU: {'Available' if _check_tpu() else 'Not available'}")
463
+
464
+ return 0
465
+
466
+
467
+ def info_command(args):
468
+ """Handle the info command - show quick start guide."""
469
+ if RICH_AVAILABLE:
470
+ console.print()
471
+
472
+ # Quick start panel
473
+ quick_start = Text()
474
+ quick_start.append("1. Prepare your data\n", style="bold cyan")
475
+ quick_start.append(" Place your training text in a .txt or .json file\n\n")
476
+ quick_start.append("2. Start training\n", style="bold cyan")
477
+ quick_start.append(" langtune train --preset small --train-file data.txt\n\n", style="green")
478
+ quick_start.append("3. Evaluate your model\n", style="bold cyan")
479
+ quick_start.append(" langtune evaluate --config config.yaml --model-path model.pt\n\n", style="green")
480
+ quick_start.append("4. Generate text\n", style="bold cyan")
481
+ quick_start.append(" langtune generate --config config.yaml --model-path model.pt --prompt \"Hello\"\n", style="green")
482
+
483
+ panel = Panel(
484
+ quick_start,
485
+ title="[bold]šŸš€ Quick Start Guide[/]",
486
+ border_style="cyan",
487
+ box=box.ROUNDED
488
+ )
489
+ console.print(panel)
490
+
491
+ # Available presets
492
+ presets_table = Table(title="Available Model Presets", box=box.SIMPLE)
493
+ presets_table.add_column("Preset", style="cyan bold")
494
+ presets_table.add_column("Parameters", style="white")
495
+ presets_table.add_column("Use Case", style="dim")
496
+
497
+ presets_table.add_row("tiny", "~1M", "Quick experiments, testing")
498
+ presets_table.add_row("small", "~10M", "Small datasets, fast training")
499
+ presets_table.add_row("base", "~50M", "General purpose")
500
+ presets_table.add_row("large", "~100M+", "Large datasets, best quality")
501
+
502
+ console.print(presets_table)
503
+ console.print()
504
+
505
+ # Links
506
+ console.print("[dim]šŸ“š Documentation:[/] [blue underline]https://github.com/langtrain-ai/langtune[/]")
507
+ console.print("[dim]šŸ› Report issues:[/] [blue underline]https://github.com/langtrain-ai/langtune/issues[/]")
508
+ console.print()
509
+ else:
510
+ print("""
511
+ šŸš€ Quick Start Guide
512
+ ====================
513
+
514
+ 1. Prepare your data
515
+ Place your training text in a .txt or .json file
516
+
517
+ 2. Start training
518
+ langtune train --preset small --train-file data.txt
519
+
520
+ 3. Evaluate your model
521
+ langtune evaluate --config config.yaml --model-path model.pt
522
+
523
+ 4. Generate text
524
+ langtune generate --config config.yaml --model-path model.pt --prompt "Hello"
525
+
526
+ Available Presets: tiny, small, base, large
527
+
528
+ šŸ“š Docs: https://github.com/langtrain-ai/langtune
529
+ """)
530
+
531
+ return 0
532
+
533
+
534
+ def _print_banner():
535
+ """Print the CLI banner."""
536
+ if RICH_AVAILABLE:
537
+ banner = Text()
538
+ banner.append("\n╔═══════════════════════════════════════════════════════════╗\n", style="cyan bold")
539
+ banner.append("ā•‘", style="cyan bold")
540
+ banner.append(" ", style="")
541
+ banner.append("LANGTUNE", style="bold magenta")
542
+ banner.append(" ", style="")
543
+ banner.append("ā•‘\n", style="cyan bold")
544
+ banner.append("ā•‘", style="cyan bold")
545
+ banner.append(" Efficient LoRA Fine-Tuning for LLMs ", style="dim")
546
+ banner.append("ā•‘\n", style="cyan bold")
547
+ banner.append("ā•šā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•\n", style="cyan bold")
548
+ console.print(banner)
549
+
550
+
551
+ def main():
552
+ """Main CLI entry point."""
553
+ parser = argparse.ArgumentParser(
554
+ description='Langtune: Efficient LoRA Fine-Tuning for Text LLMs',
555
+ formatter_class=argparse.RawDescriptionHelpFormatter,
556
+ epilog="""
557
+ Examples:
558
+ langtune info # Show quick start guide
559
+ langtune version # Show version info
560
+ langtune train --preset small --train-file data.txt # Train with preset
561
+ langtune train --config config.yaml # Train with config
562
+ langtune evaluate --config config.yaml --model-path m.pt # Evaluate model
563
+ langtune generate --config c.yaml --model-path m.pt # Generate text
564
+ langtune concept --concept rlhf # Concept demo
565
+
566
+ Learn more: https://github.com/langtrain-ai/langtune
567
+ """
568
+ )
569
+
570
+ parser.add_argument('-v', '--version', action='store_true', help='Show version information')
571
+
572
+ subparsers = parser.add_subparsers(dest='command', help='Available commands')
573
+
574
+ # Auth command
575
+ auth_parser = subparsers.add_parser('auth', help='Manage API key authentication')
576
+ auth_subparsers = auth_parser.add_subparsers(dest='auth_command', help='Auth commands')
577
+ auth_subparsers.add_parser('login', help='Login with your API key')
578
+ auth_subparsers.add_parser('logout', help='Remove stored API key')
579
+ auth_subparsers.add_parser('status', help='Show authentication status and usage')
580
+
581
+ # Version command
582
+ subparsers.add_parser('version', help='Show version and system information')
583
+
584
+ # Info command
585
+ subparsers.add_parser('info', help='Show quick start guide and documentation')
586
+
587
+ # Train command
588
+ train_parser = subparsers.add_parser('train', help='Train a model with LoRA')
589
+ train_parser.add_argument('--config', type=str, help='Path to configuration file')
590
+ train_parser.add_argument('--preset', type=str, choices=['tiny', 'small', 'base', 'large'],
591
+ help='Use a preset configuration')
592
+ train_parser.add_argument('--train-file', type=str, help='Path to training data file')
593
+ train_parser.add_argument('--eval-file', type=str, help='Path to evaluation data file')
594
+ train_parser.add_argument('--output-dir', type=str, help='Output directory for checkpoints')
595
+ train_parser.add_argument('--batch-size', type=int, help='Batch size')
596
+ train_parser.add_argument('--learning-rate', type=float, help='Learning rate')
597
+ train_parser.add_argument('--epochs', type=int, help='Number of epochs')
598
+ train_parser.add_argument('--resume-from', type=str, help='Resume from checkpoint')
599
+
600
+ # Optimization flags
601
+ train_parser.add_argument('--fast', action='store_true',
602
+ help='Use FastLoRALanguageModel with all optimizations (RoPE, flash attention, grad checkpointing)')
603
+ train_parser.add_argument('--4bit', dest='use_4bit', action='store_true',
604
+ help='Use 4-bit quantization (QLoRA style)')
605
+ train_parser.add_argument('--gradient-checkpointing', action='store_true',
606
+ help='Enable gradient checkpointing to reduce memory')
607
+ train_parser.add_argument('--mixed-precision', type=str, choices=['fp16', 'bf16', 'fp32'],
608
+ default='fp16', help='Mixed precision training mode')
609
+ train_parser.add_argument('--gradient-accumulation', type=int, default=1,
610
+ help='Number of gradient accumulation steps')
611
+
612
+ # Evaluate command
613
+ eval_parser = subparsers.add_parser('evaluate', help='Evaluate a trained model')
614
+ eval_parser.add_argument('--config', type=str, required=True, help='Path to configuration file')
615
+ eval_parser.add_argument('--model-path', type=str, required=True, help='Path to model checkpoint')
616
+
617
+ # Generate command
618
+ gen_parser = subparsers.add_parser('generate', help='Generate text with a trained model')
619
+ gen_parser.add_argument('--config', type=str, required=True, help='Path to configuration file')
620
+ gen_parser.add_argument('--model-path', type=str, required=True, help='Path to model checkpoint')
621
+ gen_parser.add_argument('--prompt', type=str, help='Text prompt for generation')
622
+ gen_parser.add_argument('--max-length', type=int, help='Maximum generation length')
623
+ gen_parser.add_argument('--temperature', type=float, default=1.0, help='Sampling temperature')
624
+ gen_parser.add_argument('--top-k', type=int, help='Top-k sampling')
625
+ gen_parser.add_argument('--top-p', type=float, help='Top-p (nucleus) sampling')
626
+
627
+ # Concept command
628
+ concept_parser = subparsers.add_parser('concept', help='Run a concept demonstration')
629
+ concept_parser.add_argument('--concept', type=str, required=True,
630
+ choices=['rlhf', 'cot', 'ccot', 'grpo', 'rlvr', 'dpo', 'ppo', 'lime', 'shap'],
631
+ help='LLM concept to demonstrate')
632
+
633
+ args = parser.parse_args()
634
+
635
+ # Handle -v/--version flag
636
+ if args.version:
637
+ return version_command(args)
638
+
639
+ if not args.command:
640
+ _print_banner()
641
+ parser.print_help()
642
+ if RICH_AVAILABLE:
643
+ console.print("\n[dim]šŸ’” Tip: Run[/] [cyan]langtune info[/] [dim]for a quick start guide[/]\n")
644
+ return 1
645
+
646
+ # Route to appropriate command handler
647
+ if args.command == 'auth':
648
+ if not args.auth_command:
649
+ # Show auth help
650
+ if RICH_AVAILABLE:
651
+ console.print("\n[bold cyan]šŸ” Authentication Commands[/]\n")
652
+ console.print(" [cyan]langtune auth login[/] - Login with your API key")
653
+ console.print(" [cyan]langtune auth logout[/] - Remove stored API key")
654
+ console.print(" [cyan]langtune auth status[/] - Show auth status and usage\n")
655
+ console.print("[dim]Get your API key at:[/] [blue underline]https://app.langtrain.xyz[/]\n")
656
+ else:
657
+ print("\nAuthentication Commands:\n")
658
+ print(" langtune auth login - Login with your API key")
659
+ print(" langtune auth logout - Remove stored API key")
660
+ print(" langtune auth status - Show auth status and usage\n")
661
+ return 0
662
+ elif args.auth_command == 'login':
663
+ return 0 if interactive_login() else 1
664
+ elif args.auth_command == 'logout':
665
+ logout()
666
+ return 0
667
+ elif args.auth_command == 'status':
668
+ print_usage_info()
669
+ return 0
670
+ elif args.command == 'version':
671
+ return version_command(args)
672
+ elif args.command == 'info':
673
+ return info_command(args)
674
+ elif args.command == 'train':
675
+ return train_command(args)
676
+ elif args.command == 'evaluate':
677
+ return evaluate_command(args)
678
+ elif args.command == 'generate':
679
+ return generate_command(args)
680
+ elif args.command == 'concept':
681
+ return concept_command(args)
682
+ else:
683
+ parser.print_help()
684
+ return 1
685
+
686
+ if __name__ == '__main__':
687
+ sys.exit(main())