langtune 0.1.19__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
langtune/__init__.py ADDED
@@ -0,0 +1,315 @@
1
+ """
2
+ Langtune: Efficient LoRA Fine-Tuning for Text LLMs
3
+
4
+ This package provides tools and modules for efficient fine-tuning of large language models (LLMs) on text data using Low-Rank Adaptation (LoRA).
5
+ """
6
+
7
+ import os
8
+ import sys
9
+
10
+ __version__ = "0.1.2"
11
+
12
+ # Banner display control
13
+ _BANNER_SHOWN = False
14
+ _SHOW_BANNER = os.environ.get("LANGTUNE_NO_BANNER", "0") != "1"
15
+
16
+
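Note: the banner is controlled entirely by the LANGTUNE_NO_BANNER environment variable, which is read once at import time. A minimal sketch of how a user would suppress it (the variable must be set before langtune is first imported):

    import os
    os.environ["LANGTUNE_NO_BANNER"] = "1"  # read by langtune at import time

    import langtune  # imported with the banner suppressed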
17
+ def _check_tpu_available() -> bool:
18
+ """Check if Google TPU is available via torch_xla."""
19
+ try:
20
+ import torch_xla
21
+ import torch_xla.core.xla_model as xm
22
+ # Try to get a TPU device
23
+ device = xm.xla_device()
24
+ return "TPU" in str(device) or "xla" in str(device).lower()
25
+     except Exception:  # Exception already covers ImportError and RuntimeError
26
+ return False
27
+
28
+
29
+ def _get_tpu_version() -> str:
30
+ """Get TPU version if available."""
31
+ try:
32
+ import torch_xla
33
+ # Try to detect TPU version from environment or device info
34
+ import os
35
+ tpu_name = os.environ.get("TPU_NAME", "")
36
+ if "v4" in tpu_name.lower():
37
+ return "4"
38
+ elif "v3" in tpu_name.lower():
39
+ return "3"
40
+ elif "v2" in tpu_name.lower():
41
+ return "2"
42
+ return "4" # Default to v4 for newer TPUs
43
+     except Exception:
44
+ return "?"
45
+
46
+
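Note: _get_tpu_version keys off the TPU_NAME environment variable. A small illustration of the mapping it implements, assuming torch_xla is importable (otherwise it returns "?"):

    import os
    os.environ["TPU_NAME"] = "v3-8"
    _get_tpu_version()  # -> "3"
    os.environ["TPU_NAME"] = "my-tpu"
    _get_tpu_version()  # -> "4" (default when no version marker is present)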
47
+ def _show_welcome_banner():
48
+ """Display a beautiful welcome banner on first import."""
49
+ global _BANNER_SHOWN
50
+
51
+ if _BANNER_SHOWN or not _SHOW_BANNER:
52
+ return
53
+
54
+ _BANNER_SHOWN = True
55
+
56
+ try:
57
+ from rich.console import Console
58
+ from rich.panel import Panel
59
+ from rich.text import Text
60
+ from rich import box
61
+ import torch
62
+
63
+ console = Console()
64
+
65
+ # Check GPU/TPU availability with detailed info
66
+ # Check for NVIDIA CUDA
67
+ if torch.cuda.is_available():
68
+ gpu_name = torch.cuda.get_device_name(0)
69
+ gpu_count = torch.cuda.device_count()
70
+
71
+ # Get NVIDIA-specific details
72
+ try:
73
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
74
+ cuda_version = torch.version.cuda
75
+
76
+ if gpu_count > 1:
77
+ gpu_info = f"✓ NVIDIA: {gpu_name} × {gpu_count} ({gpu_memory:.0f}GB each, CUDA {cuda_version})"
78
+ else:
79
+ gpu_info = f"✓ NVIDIA: {gpu_name} ({gpu_memory:.0f}GB, CUDA {cuda_version})"
80
+             except Exception:
81
+ gpu_info = f"✓ NVIDIA: {gpu_name}"
82
+
83
+ gpu_style = "green"
84
+ # Check for Google TPU (via torch_xla)
85
+ elif _check_tpu_available():
86
+ try:
87
+ import torch_xla.core.xla_model as xm
88
+ tpu_count = xm.xrt_world_size()
89
+ gpu_info = f"✓ TPU: Google Cloud TPU v{_get_tpu_version()} ({tpu_count} cores)"
90
+             except Exception:
91
+ gpu_info = "✓ TPU: Google Cloud TPU"
92
+ gpu_style = "green"
93
+ # Check for Apple MPS
94
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
95
+ gpu_info = "✓ Apple: Metal Performance Shaders (MPS)"
96
+ gpu_style = "green"
97
+ else:
98
+ gpu_info = "○ Accelerator: Not available (CPU mode)"
99
+ gpu_style = "yellow"
100
+
101
+ # Create banner content
102
+ banner_text = Text()
103
+
104
+ # Logo/Title
105
+ banner_text.append("╔═══════════════════════════════════════════════════════════╗\n", style="cyan bold")
106
+ banner_text.append("║", style="cyan bold")
107
+ banner_text.append(" ", style="")
108
+ banner_text.append("LANGTUNE", style="bold magenta")
109
+ banner_text.append(" ", style="")
110
+ banner_text.append("║\n", style="cyan bold")
111
+ banner_text.append("║", style="cyan bold")
112
+ banner_text.append(" Efficient LoRA Fine-Tuning for LLMs ", style="dim")
113
+ banner_text.append("║\n", style="cyan bold")
114
+ banner_text.append("╚═══════════════════════════════════════════════════════════╝\n", style="cyan bold")
115
+
116
+ # Info section
117
+ banner_text.append("\n")
118
+ banner_text.append(" 📦 Version: ", style="dim")
119
+ banner_text.append(f"v{__version__}\n", style="cyan")
120
+ banner_text.append(f" 🖥️ {gpu_info}\n", style=gpu_style)
121
+ banner_text.append(" 📚 Docs: ", style="dim")
122
+ banner_text.append("https://github.com/langtrain-ai/langtune\n", style="blue underline")
123
+
124
+ # Quick start
125
+ banner_text.append("\n 🚀 ", style="")
126
+ banner_text.append("Quick Start:\n", style="bold")
127
+ banner_text.append(" 1. langtune auth login ", style="cyan")
128
+ banner_text.append("# Get key at langtrain.xyz\n", style="dim")
129
+ banner_text.append(" 2. langtune train --preset small --train-file data.txt\n", style="cyan")
130
+
131
+ # Tips
132
+ banner_text.append("\n 💡 ", style="")
133
+ banner_text.append("Tip: ", style="yellow bold")
134
+ banner_text.append("Set LANGTUNE_NO_BANNER=1 to disable this message\n", style="dim")
135
+
136
+ console.print(banner_text)
137
+
138
+ except ImportError:
139
+ # Fallback to simple banner if rich is not available
140
+ print(f"""
141
+ ╔═══════════════════════════════════════════════════════════╗
142
+ ║ LANGTUNE ║
143
+ ║ Efficient LoRA Fine-Tuning for LLMs ║
144
+ ╚═══════════════════════════════════════════════════════════╝
145
+
146
+ 📦 Version: v{__version__}
147
+ 📚 Docs: https://langtrain.xyz
148
+
149
+ 🚀 Quick Start:
150
+ 1. langtune auth login # Get key at langtrain.xyz
151
+ 2. langtune train --preset small --train-file data.txt
152
+
153
+ 💡 Tip: Set LANGTUNE_NO_BANNER=1 to disable this message
154
+ """)
155
+
156
+
157
+ # Show banner on import (unless in non-interactive mode)
158
+ if sys.stdout.isatty():
159
+ _show_welcome_banner()
160
+
161
+
162
+ # Core models
163
+ from .models import (
164
+ LoRALanguageModel, LoRALinear, MultiHeadAttention, TransformerBlock,
165
+ FastLoRALanguageModel, FastMultiHeadAttention, FastTransformerBlock
166
+ )
167
+
168
+ # Optimizations
169
+ from .optimizations import (
170
+ OptimizationConfig, QuantizedLinear, LoRALinear4bit,
171
+ RotaryPositionEmbedding, MemoryEfficientAttention,
172
+ fused_cross_entropy, checkpoint, MixedPrecisionTrainer,
173
+ get_memory_stats, cleanup_memory
174
+ )
175
+
176
+ # Configuration
177
+ from .config import (
178
+ Config, ModelConfig, TrainingConfig, DataConfig, LoRAConfig,
179
+ default_config, load_config, save_config, get_preset_config, validate_config
180
+ )
181
+
182
+ # Data handling
183
+ from .data import (
184
+ TextDataset, LanguageModelingDataset, DataCollator,
185
+ load_text_file, load_json_file, create_data_loader, split_dataset,
186
+ SimpleTokenizer, create_sample_dataset, load_dataset_from_config
187
+ )
188
+
189
+ # Training
190
+ from .trainer import (
191
+ Trainer, FastTrainer, EarlyStopping, MetricsTracker, ModelCheckpoint,
192
+ create_trainer, create_fast_trainer
193
+ )
194
+
195
+ # Fine-tuning (best-practice API)
196
+ from .finetune import (
197
+ finetune as local_finetune, finetune_from_config, FineTuneConfig
198
+ )
199
+
200
+ # Client SDK
201
+ from .client import LangtuneClient, FineTuneJob, JobStatus, Model, APIError, get_client
202
+
203
+ # High-level API (server + local)
204
+ from .api import finetune, generate, chat, list_models, list_jobs, get_job, cancel_job
205
+
206
+ # Callbacks
207
+ from .callbacks import (
208
+ Callback, CallbackList, ProgressCallback, LearningRateMonitorCallback,
209
+ GradientMonitorCallback, ModelSizeCallback, TimerCallback, SaveHistoryCallback,
210
+ MemoryMonitorCallback, WandbCallback, get_default_callbacks, get_verbose_callbacks
211
+ )
212
+
213
+ # Schedulers
214
+ from .schedulers import (
215
+ WarmupScheduler, CosineAnnealingWithWarmup, LinearDecayWithWarmup,
216
+ PolynomialDecayWithWarmup, ConstantWithWarmup, OneCycleLRWithWarmup, get_scheduler
217
+ )
218
+
219
+ # Metrics
220
+ from .metrics import (
221
+ compute_perplexity, compute_accuracy, compute_top_k_accuracy,
222
+ compute_bleu, compute_rouge_l, compute_diversity, MetricsCalculator
223
+ )
224
+
225
+ # Generation
226
+ from .generation import TextGenerator, generate  # NOTE: re-binds the 'generate' imported from .api above
227
+
228
+ # Tokenizers
229
+ from .tokenizers import CharacterTokenizer, WordTokenizer, BPETokenizer, get_tokenizer
230
+
231
+ # Distributed
232
+ from .distributed import (
233
+ is_distributed, get_rank, get_world_size, is_main_process,
234
+ setup_distributed, cleanup_distributed, wrap_model_ddp, get_distributed_sampler
235
+ )
236
+
237
+ # Logging
238
+ from .logging_utils import (
239
+ setup_logging, get_logger, TrainingLogger, ProgressTracker,
240
+ print_banner, print_metrics
241
+ )
242
+
243
+ # Utilities
244
+ from .utils import (
245
+ set_seed, get_device, count_parameters, count_lora_parameters,
246
+     encode_text, decode_tokens, SimpleTokenizer, create_attention_mask,  # NOTE: re-binds SimpleTokenizer imported from .data
247
+ pad_sequences, truncate_sequences, compute_perplexity, compute_bleu_score,
248
+ format_time, format_size, get_model_size, print_model_summary,
249
+ save_model_info, load_model_info, log_gpu_memory, cleanup_gpu_memory
250
+ )
251
+
252
+ # Authentication
253
+ from .auth import (
254
+ get_api_key, set_api_key, verify_api_key, check_usage,
255
+ interactive_login, logout, print_usage_info,
256
+ AuthenticationError, UsageLimitError, require_auth
257
+ )
258
+
259
+ # CLI
260
+ from .cli import main
261
+
262
+ # Facades (Quick Start API)
263
+ from .facade import LoRATrainer, QLoRATrainer, ChatModel, deploy
264
+
265
+ __all__ = [
266
+ # Models
267
+ "LoRALanguageModel", "LoRALinear", "MultiHeadAttention", "TransformerBlock",
268
+ "FastLoRALanguageModel", "FastMultiHeadAttention", "FastTransformerBlock",
270
+
271
+ # Optimizations
272
+ "OptimizationConfig", "QuantizedLinear", "LoRALinear4bit",
273
+ "RotaryPositionEmbedding", "MemoryEfficientAttention",
274
+ "fused_cross_entropy", "checkpoint", "MixedPrecisionTrainer",
275
+ "get_memory_stats", "cleanup_memory",
276
+
277
+ # Configuration
278
+ "Config", "ModelConfig", "TrainingConfig", "DataConfig", "LoRAConfig",
279
+ "default_config", "load_config", "save_config", "get_preset_config", "validate_config",
280
+
281
+ # Data
282
+ "TextDataset", "LanguageModelingDataset", "DataCollator",
283
+ "load_text_file", "load_json_file", "create_data_loader", "split_dataset",
284
+ "SimpleTokenizer", "create_sample_dataset", "load_dataset_from_config",
285
+
286
+ # Training
287
+ "Trainer", "FastTrainer", "EarlyStopping", "MetricsTracker", "ModelCheckpoint",
288
+ "create_trainer", "create_fast_trainer",
289
+
290
+ # Fine-tuning
291
+ "finetune", "finetune_from_config", "FineTuneConfig", "train", "fine_tune",
292
+
293
+ # Utilities
294
+ "set_seed", "get_device", "count_parameters", "count_lora_parameters",
295
+ "encode_text", "decode_tokens", "create_attention_mask",
296
+ "pad_sequences", "truncate_sequences", "compute_perplexity", "compute_bleu_score",
297
+ "format_time", "format_size", "get_model_size", "print_model_summary",
298
+ "save_model_info", "load_model_info", "log_gpu_memory", "cleanup_gpu_memory",
299
+
300
+ # Authentication
301
+ "get_api_key", "set_api_key", "verify_api_key", "check_usage",
302
+ "interactive_login", "logout", "print_usage_info",
303
+ "AuthenticationError", "UsageLimitError", "require_auth",
304
+
305
+ # CLI
306
+ "main",
307
+
308
+ # Facades
309
+ "LoRATrainer", "QLoRATrainer", "ChatModel", "deploy",
310
+
311
+ # Version
312
+ "__version__"
313
+ ]
314
+
315
+
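Taken together, the imports and __all__ above define the public, top-level API. A minimal usage sketch; the argument conventions for set_seed and get_device are assumptions, since their signatures are not shown in this diff:

    import langtune

    print(langtune.__version__)
    langtune.set_seed(42)           # assumed: takes a single integer seed
    device = langtune.get_device()  # assumed: returns the best available torch device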
@@ -0,0 +1,132 @@
1
+ """
2
+ Accelerator module for Langtune.
3
+
4
+ This module provides an interface to access high-performance custom kernels
5
+ from langtrain-server if they are available in the environment.
6
+ """
7
+
8
+ import logging
9
+ import torch
10
+ import torch.nn as nn
11
+ from typing import Optional, Tuple, Any
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Autograd Functions
16
+ class FusedRMSNormFunction(torch.autograd.Function):
17
+ @staticmethod
18
+ def forward(ctx, input, weight, eps, kernels):
19
+ ctx.save_for_backward(input, weight)
20
+ ctx.eps = eps
21
+ ctx.kernels = kernels
22
+ return kernels.fused_rmsnorm_forward(input, weight, eps)
23
+
24
+ @staticmethod
25
+ def backward(ctx, grad_output):
26
+ input, weight = ctx.saved_tensors
27
+ grad_input, grad_weight = ctx.kernels.fused_rmsnorm_backward(grad_output, input, weight, ctx.eps)
28
+ return grad_input, grad_weight, None, None
29
+
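For reference, a plain-PyTorch RMSNorm is sketched below. This is an assumption about what the fused_rmsnorm_forward kernel computes; the kernel's exact semantics are not shown in this diff.

    import torch

    def reference_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Scale x by the reciprocal root-mean-square over the last dimension, then apply the learned weight.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight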
30
+ class FusedLoRAFunction(torch.autograd.Function):
31
+ @staticmethod
32
+ def forward(ctx, input, base_weight, lora_A, lora_B, scaling, kernels):
33
+ ctx.save_for_backward(input, base_weight, lora_A, lora_B)
34
+ ctx.scaling = scaling
35
+ ctx.kernels = kernels
36
+ return kernels.fused_lora_forward(input, base_weight, lora_A, lora_B, scaling)
37
+
38
+ @staticmethod
39
+ def backward(ctx, grad_output):
40
+ input, base_weight, lora_A, lora_B = ctx.saved_tensors
41
+ grad_input, grad_A, grad_B = ctx.kernels.lora_backward(grad_output, input, lora_A, lora_B, base_weight, ctx.scaling)
42
+ return grad_input, None, grad_A, grad_B, None, None
43
+
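Similarly, the unfused LoRA forward that fused_lora_forward presumably replaces is sketched below, under the usual convention that lora_A has shape (r, in_features) and lora_B has shape (out_features, r); the kernel's actual layout is an assumption.

    import torch.nn.functional as F

    def reference_lora_forward(x, base_weight, lora_A, lora_B, scaling):
        # Base projection plus the scaled low-rank update: x @ W.T + scaling * ((x @ A.T) @ B.T)
        return F.linear(x, base_weight) + scaling * F.linear(F.linear(x, lora_A), lora_B)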
44
+ class Accelerator:
45
+ """
46
+ Manages access to accelerated kernels.
47
+ """
48
+
49
+ _instance = None
50
+
51
+ def __new__(cls):
52
+ if cls._instance is None:
53
+ cls._instance = super(Accelerator, cls).__new__(cls)
54
+ cls._instance._initialize()
55
+ return cls._instance
56
+
57
+ def _initialize(self):
58
+ self.available = False
59
+ self.kernels = None
60
+
61
+ try:
62
+ import langtrain_cuda
63
+ self.kernels = langtrain_cuda
64
+ self.available = True
65
+ logger.info("Langtrain high-performance kernels detected and enabled.")
66
+ except ImportError:
67
+ logger.info("Langtrain kernels not found. Using standard PyTorch implementations.")
68
+
69
+ def is_available(self) -> bool:
70
+ return self.available
71
+
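Because __new__ caches a single instance, every call to Accelerator() returns the same object:

    acc_a = Accelerator()
    acc_b = Accelerator()
    assert acc_a is acc_b          # singleton: both names refer to the same instance
    print(acc_a.is_available())    # True only if langtrain_cuda imported successfully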
72
+ def fused_attention(
73
+ self,
74
+ query: torch.Tensor,
75
+ key: torch.Tensor,
76
+ value: torch.Tensor,
77
+ is_causal: bool = True,
78
+         scale: Optional[float] = None
79
+ ) -> torch.Tensor:
80
+ """
81
+ Run fused attention forward pass.
82
+ """
83
+ if self.available and self.kernels:
84
+ if scale is None:
85
+ scale = query.size(-1) ** -0.5
86
+             # Fall back to PyTorch's scaled_dot_product_attention (Flash-enabled on PyTorch 2.0+)
87
+             # until the custom kernel exposes a safe backward pass.
88
+ return torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=is_causal, scale=scale)
89
+ else:
90
+ return torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=is_causal, scale=scale)
91
+
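Both branches currently route through torch.nn.functional.scaled_dot_product_attention, which takes query/key/value laid out as (batch, num_heads, seq_len, head_dim). A small usage sketch with illustrative shapes:

    q = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 8, 128, 64)
    v = torch.randn(1, 8, 128, 64)
    out = Accelerator().fused_attention(q, k, v, is_causal=True)  # same shape as q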
92
+ def fused_rmsnorm(
93
+ self,
94
+ hidden_states: torch.Tensor,
95
+ weight: torch.Tensor,
96
+ eps: float = 1e-6
97
+ ) -> torch.Tensor:
98
+ """
99
+         Run fused RMSNorm. Returns None when accelerated kernels are unavailable.
100
+ """
101
+ if self.available and self.kernels:
102
+ return FusedRMSNormFunction.apply(hidden_states, weight, eps, self.kernels)
103
+ return None
104
+
105
+ def fused_mlp(
106
+ self,
107
+ hidden_states: torch.Tensor,
108
+ gate_weight: torch.Tensor,
109
+ up_weight: torch.Tensor,
110
+ down_weight: torch.Tensor
111
+ ) -> torch.Tensor:
112
+ """
113
+ Run fused SwiGLU MLP.
114
+ """
115
+         # Fall back to the PyTorch decomposition; there is no fused backward kernel for the MLP yet.
116
+ return (torch.nn.functional.silu(torch.nn.functional.linear(hidden_states, gate_weight)) *
117
+ torch.nn.functional.linear(hidden_states, up_weight)).matmul(down_weight.t())
118
+
119
+ def fused_lora(
120
+ self,
121
+ x: torch.Tensor,
122
+ base_weight: torch.Tensor,
123
+ lora_A: torch.Tensor,
124
+ lora_B: torch.Tensor,
125
+ scaling: float
126
+ ) -> torch.Tensor:
127
+ """
128
+         Run fused LoRA forward. Returns None when accelerated kernels are unavailable.
129
+ """
130
+ if self.available and self.kernels:
131
+ return FusedLoRAFunction.apply(x, base_weight, lora_A, lora_B, scaling, self.kernels)
132
+ return None
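Note that fused_rmsnorm and fused_lora return None when the kernels are absent, so call sites presumably guard on the result. A hedged sketch of that calling pattern, reusing the reference_rmsnorm sketched earlier (hidden_states and norm_weight are hypothetical tensors):

    accel = Accelerator()
    out = accel.fused_rmsnorm(hidden_states, norm_weight, eps=1e-6)
    if out is None:
        # Kernels unavailable: fall back to a plain PyTorch RMSNorm.
        out = reference_rmsnorm(hidden_states, norm_weight, eps=1e-6)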