langtune-0.1.19-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langtune/__init__.py +315 -0
- langtune/acceleration.py +132 -0
- langtune/api.py +320 -0
- langtune/auth.py +434 -0
- langtune/callbacks.py +268 -0
- langtune/cli.py +687 -0
- langtune/client.py +721 -0
- langtune/config.py +356 -0
- langtune/data.py +526 -0
- langtune/distributed.py +154 -0
- langtune/facade.py +174 -0
- langtune/finetune.py +491 -0
- langtune/generation.py +95 -0
- langtune/logging_utils.py +182 -0
- langtune/metrics.py +345 -0
- langtune/model/__init__.py +20 -0
- langtune/model/hub.py +109 -0
- langtune/model/loader.py +84 -0
- langtune/model/safetensors.py +104 -0
- langtune/model/weights.py +100 -0
- langtune/models.py +19 -0
- langtune/nn/fast_transformer.py +399 -0
- langtune/nn/layers.py +178 -0
- langtune/nn/transformer.py +254 -0
- langtune/optimizations.py +870 -0
- langtune/py.typed +2 -0
- langtune/schedulers.py +234 -0
- langtune/tokenizers.py +275 -0
- langtune/trainer.py +889 -0
- langtune/training/neftune.py +80 -0
- langtune/utils.py +337 -0
- langtune-0.1.19.dist-info/METADATA +257 -0
- langtune-0.1.19.dist-info/RECORD +37 -0
- langtune-0.1.19.dist-info/WHEEL +5 -0
- langtune-0.1.19.dist-info/entry_points.txt +2 -0
- langtune-0.1.19.dist-info/licenses/LICENSE +21 -0
- langtune-0.1.19.dist-info/top_level.txt +1 -0
langtune/__init__.py
ADDED
@@ -0,0 +1,315 @@
"""
Langtune: Efficient LoRA Fine-Tuning for Text LLMs

This package provides tools and modules for efficient fine-tuning of large language models (LLMs) on text data using Low-Rank Adaptation (LoRA).
"""

import os
import sys

__version__ = "0.1.2"

# Banner display control
_BANNER_SHOWN = False
_SHOW_BANNER = os.environ.get("LANGTUNE_NO_BANNER", "0") != "1"


def _check_tpu_available() -> bool:
    """Check if Google TPU is available via torch_xla."""
    try:
        import torch_xla
        import torch_xla.core.xla_model as xm
        # Try to get a TPU device
        device = xm.xla_device()
        return "TPU" in str(device) or "xla" in str(device).lower()
    except (ImportError, RuntimeError, Exception):
        return False


def _get_tpu_version() -> str:
    """Get TPU version if available."""
    try:
        import torch_xla
        # Try to detect TPU version from environment or device info
        import os
        tpu_name = os.environ.get("TPU_NAME", "")
        if "v4" in tpu_name.lower():
            return "4"
        elif "v3" in tpu_name.lower():
            return "3"
        elif "v2" in tpu_name.lower():
            return "2"
        return "4"  # Default to v4 for newer TPUs
    except:
        return "?"


def _show_welcome_banner():
    """Display a beautiful welcome banner on first import."""
    global _BANNER_SHOWN

    if _BANNER_SHOWN or not _SHOW_BANNER:
        return

    _BANNER_SHOWN = True

    try:
        from rich.console import Console
        from rich.panel import Panel
        from rich.text import Text
        from rich import box
        import torch

        console = Console()

        # Check GPU/TPU availability with detailed info
        # Check for NVIDIA CUDA
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            gpu_count = torch.cuda.device_count()

            # Get NVIDIA-specific details
            try:
                gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                cuda_version = torch.version.cuda

                if gpu_count > 1:
                    gpu_info = f"✓ NVIDIA: {gpu_name} × {gpu_count} ({gpu_memory:.0f}GB each, CUDA {cuda_version})"
                else:
                    gpu_info = f"✓ NVIDIA: {gpu_name} ({gpu_memory:.0f}GB, CUDA {cuda_version})"
            except:
                gpu_info = f"✓ NVIDIA: {gpu_name}"

            gpu_style = "green"
        # Check for Google TPU (via torch_xla)
        elif _check_tpu_available():
            try:
                import torch_xla.core.xla_model as xm
                tpu_count = xm.xrt_world_size()
                gpu_info = f"✓ TPU: Google Cloud TPU v{_get_tpu_version()} ({tpu_count} cores)"
            except:
                gpu_info = "✓ TPU: Google Cloud TPU"
            gpu_style = "green"
        # Check for Apple MPS
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            gpu_info = "✓ Apple: Metal Performance Shaders (MPS)"
            gpu_style = "green"
        else:
            gpu_info = "○ Accelerator: Not available (CPU mode)"
            gpu_style = "yellow"

        # Create banner content
        banner_text = Text()

        # Logo/Title
        banner_text.append("╔═══════════════════════════════════════════════════════════╗\n", style="cyan bold")
        banner_text.append("║", style="cyan bold")
        banner_text.append(" ", style="")
        banner_text.append("LANGTUNE", style="bold magenta")
        banner_text.append(" ", style="")
        banner_text.append("║\n", style="cyan bold")
        banner_text.append("║", style="cyan bold")
        banner_text.append(" Efficient LoRA Fine-Tuning for LLMs ", style="dim")
        banner_text.append("║\n", style="cyan bold")
        banner_text.append("╚═══════════════════════════════════════════════════════════╝\n", style="cyan bold")

        # Info section
        banner_text.append("\n")
        banner_text.append(" 📦 Version: ", style="dim")
        banner_text.append(f"v{__version__}\n", style="cyan")
        banner_text.append(f" 🖥️ {gpu_info}\n", style=gpu_style)
        banner_text.append(" 📚 Docs: ", style="dim")
        banner_text.append("https://github.com/langtrain-ai/langtune\n", style="blue underline")

        # Quick start
        banner_text.append("\n 🚀 ", style="")
        banner_text.append("Quick Start:\n", style="bold")
        banner_text.append(" 1. langtune auth login ", style="cyan")
        banner_text.append("# Get key at langtrain.xyz\n", style="dim")
        banner_text.append(" 2. langtune train --preset small --train-file data.txt\n", style="cyan")

        # Tips
        banner_text.append("\n 💡 ", style="")
        banner_text.append("Tip: ", style="yellow bold")
        banner_text.append("Set LANGTUNE_NO_BANNER=1 to disable this message\n", style="dim")

        console.print(banner_text)

    except ImportError:
        # Fallback to simple banner if rich is not available
        print(f"""
╔═══════════════════════════════════════════════════════════╗
║ LANGTUNE ║
║ Efficient LoRA Fine-Tuning for LLMs ║
╚═══════════════════════════════════════════════════════════╝

📦 Version: v{__version__}
📚 Docs: https://langtrain.xyz

🚀 Quick Start:
1. langtune auth login # Get key at langtrain.xyz
2. langtune train --preset small --train-file data.txt

💡 Tip: Set LANGTUNE_NO_BANNER=1 to disable this message
""")


# Show banner on import (unless in non-interactive mode)
if sys.stdout.isatty():
    _show_welcome_banner()


# Core models
from .models import (
    LoRALanguageModel, LoRALinear, MultiHeadAttention, TransformerBlock,
    FastLoRALanguageModel, FastMultiHeadAttention, FastTransformerBlock
)

# Optimizations
from .optimizations import (
    OptimizationConfig, QuantizedLinear, LoRALinear4bit,
    RotaryPositionEmbedding, MemoryEfficientAttention,
    fused_cross_entropy, checkpoint, MixedPrecisionTrainer,
    get_memory_stats, cleanup_memory
)

# Configuration
from .config import (
    Config, ModelConfig, TrainingConfig, DataConfig, LoRAConfig,
    default_config, load_config, save_config, get_preset_config, validate_config
)

# Data handling
from .data import (
    TextDataset, LanguageModelingDataset, DataCollator,
    load_text_file, load_json_file, create_data_loader, split_dataset,
    SimpleTokenizer, create_sample_dataset, load_dataset_from_config
)

# Training
from .trainer import (
    Trainer, FastTrainer, EarlyStopping, MetricsTracker, ModelCheckpoint,
    create_trainer, create_fast_trainer
)

# Fine-tuning (best-practice API)
from .finetune import (
    finetune as local_finetune, finetune_from_config, FineTuneConfig
)

# Client SDK
from .client import LangtuneClient, FineTuneJob, JobStatus, Model, APIError, get_client

# High-level API (server + local)
from .api import finetune, generate, chat, list_models, list_jobs, get_job, cancel_job

# Callbacks
from .callbacks import (
    Callback, CallbackList, ProgressCallback, LearningRateMonitorCallback,
    GradientMonitorCallback, ModelSizeCallback, TimerCallback, SaveHistoryCallback,
    MemoryMonitorCallback, WandbCallback, get_default_callbacks, get_verbose_callbacks
)

# Schedulers
from .schedulers import (
    WarmupScheduler, CosineAnnealingWithWarmup, LinearDecayWithWarmup,
    PolynomialDecayWithWarmup, ConstantWithWarmup, OneCycleLRWithWarmup, get_scheduler
)

# Metrics
from .metrics import (
    compute_perplexity, compute_accuracy, compute_top_k_accuracy,
    compute_bleu, compute_rouge_l, compute_diversity, MetricsCalculator
)

# Generation
from .generation import TextGenerator, generate

# Tokenizers
from .tokenizers import CharacterTokenizer, WordTokenizer, BPETokenizer, get_tokenizer

# Distributed
from .distributed import (
    is_distributed, get_rank, get_world_size, is_main_process,
    setup_distributed, cleanup_distributed, wrap_model_ddp, get_distributed_sampler
)

# Logging
from .logging_utils import (
    setup_logging, get_logger, TrainingLogger, ProgressTracker,
    print_banner, print_metrics
)

# Utilities
from .utils import (
    set_seed, get_device, count_parameters, count_lora_parameters,
    encode_text, decode_tokens, SimpleTokenizer, create_attention_mask,
    pad_sequences, truncate_sequences, compute_perplexity, compute_bleu_score,
    format_time, format_size, get_model_size, print_model_summary,
    save_model_info, load_model_info, log_gpu_memory, cleanup_gpu_memory
)

# Authentication
from .auth import (
    get_api_key, set_api_key, verify_api_key, check_usage,
    interactive_login, logout, print_usage_info,
    AuthenticationError, UsageLimitError, require_auth
)

# CLI
from .cli import main

# Facades (Quick Start API)
from .facade import LoRATrainer, QLoRATrainer, ChatModel, deploy

__all__ = [
    # Models
    "LoRALanguageModel", "LoRALinear", "MultiHeadAttention", "TransformerBlock",
    "FastLoRALanguageModel", "FastMultiHeadAttention", "FastTransformerBlock",
    "FastLoRALanguageModel", "FastMultiHeadAttention", "FastTransformerBlock",

    # Optimizations
    "OptimizationConfig", "QuantizedLinear", "LoRALinear4bit",
    "RotaryPositionEmbedding", "MemoryEfficientAttention",
    "fused_cross_entropy", "checkpoint", "MixedPrecisionTrainer",
    "get_memory_stats", "cleanup_memory",

    # Configuration
    "Config", "ModelConfig", "TrainingConfig", "DataConfig", "LoRAConfig",
    "default_config", "load_config", "save_config", "get_preset_config", "validate_config",

    # Data
    "TextDataset", "LanguageModelingDataset", "DataCollator",
    "load_text_file", "load_json_file", "create_data_loader", "split_dataset",
    "SimpleTokenizer", "create_sample_dataset", "load_dataset_from_config",

    # Training
    "Trainer", "FastTrainer", "EarlyStopping", "MetricsTracker", "ModelCheckpoint",
    "create_trainer", "create_fast_trainer",

    # Fine-tuning
    "finetune", "finetune_from_config", "FineTuneConfig", "train", "fine_tune",

    # Utilities
    "set_seed", "get_device", "count_parameters", "count_lora_parameters",
    "encode_text", "decode_tokens", "create_attention_mask",
    "pad_sequences", "truncate_sequences", "compute_perplexity", "compute_bleu_score",
    "format_time", "format_size", "get_model_size", "print_model_summary",
    "save_model_info", "load_model_info", "log_gpu_memory", "cleanup_gpu_memory",

    # Authentication
    "get_api_key", "set_api_key", "verify_api_key", "check_usage",
    "interactive_login", "logout", "print_usage_info",
    "AuthenticationError", "UsageLimitError", "require_auth",

    # CLI
    "main",

    # Facades
    "LoRATrainer", "QLoRATrainer", "ChatModel", "deploy",

    # Version
    "__version__"
]
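Two practical notes on the __init__.py above: the banner only prints when stdout is a TTY and LANGTUNE_NO_BANNER is unset, and later imports rebind earlier names (generate ends up bound to .generation.generate rather than .api.generate, and SimpleTokenizer to the .utils version rather than the .data one). A minimal sketch of silencing the banner, assuming only the behaviour visible in this file; note that the hard-coded __version__ string is "0.1.2" even though the wheel is 0.1.19:

import os
os.environ["LANGTUNE_NO_BANNER"] = "1"   # must be set before the first import; read once at import time

import langtune                          # no banner; repeat imports stay silent via _BANNER_SHOWN
print(langtune.__version__)              # prints "0.1.2" despite the 0.1.19 wheel metadata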
langtune/acceleration.py
ADDED
@@ -0,0 +1,132 @@
"""
Accelerator module for Langtune.

This module provides an interface to access high-performance custom kernels
from langtrain-server if they are available in the environment.
"""

import logging
import torch
import torch.nn as nn
from typing import Optional, Tuple, Any

logger = logging.getLogger(__name__)

# Autograd Functions
class FusedRMSNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, eps, kernels):
        ctx.save_for_backward(input, weight)
        ctx.eps = eps
        ctx.kernels = kernels
        return kernels.fused_rmsnorm_forward(input, weight, eps)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input, grad_weight = ctx.kernels.fused_rmsnorm_backward(grad_output, input, weight, ctx.eps)
        return grad_input, grad_weight, None, None

class FusedLoRAFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, base_weight, lora_A, lora_B, scaling, kernels):
        ctx.save_for_backward(input, base_weight, lora_A, lora_B)
        ctx.scaling = scaling
        ctx.kernels = kernels
        return kernels.fused_lora_forward(input, base_weight, lora_A, lora_B, scaling)

    @staticmethod
    def backward(ctx, grad_output):
        input, base_weight, lora_A, lora_B = ctx.saved_tensors
        grad_input, grad_A, grad_B = ctx.kernels.lora_backward(grad_output, input, lora_A, lora_B, base_weight, ctx.scaling)
        return grad_input, None, grad_A, grad_B, None, None

class Accelerator:
    """
    Manages access to accelerated kernels.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(Accelerator, cls).__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        self.available = False
        self.kernels = None

        try:
            import langtrain_cuda
            self.kernels = langtrain_cuda
            self.available = True
            logger.info("Langtrain high-performance kernels detected and enabled.")
        except ImportError:
            logger.info("Langtrain kernels not found. Using standard PyTorch implementations.")

    def is_available(self) -> bool:
        return self.available

    def fused_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        is_causal: bool = True,
        scale: float = None
    ) -> torch.Tensor:
        """
        Run fused attention forward pass.
        """
        if self.available and self.kernels:
            if scale is None:
                scale = query.size(-1) ** -0.5
            # Attention falls back to PyTorch for backward safety if the kernel lacks it.
            # For now, we use standard SDPA, which is Flash-enabled in PyTorch 2.0+.
            return torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=is_causal, scale=scale)
        else:
            return torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=is_causal, scale=scale)

    def fused_rmsnorm(
        self,
        hidden_states: torch.Tensor,
        weight: torch.Tensor,
        eps: float = 1e-6
    ) -> torch.Tensor:
        """
        Run fused RMSNorm.
        """
        if self.available and self.kernels:
            return FusedRMSNormFunction.apply(hidden_states, weight, eps, self.kernels)
        return None

    def fused_mlp(
        self,
        hidden_states: torch.Tensor,
        gate_weight: torch.Tensor,
        up_weight: torch.Tensor,
        down_weight: torch.Tensor
    ) -> torch.Tensor:
        """
        Run fused SwiGLU MLP.
        """
        # Fall back to the PyTorch decomposition, as we lack a fused backward kernel for the MLP
        return (torch.nn.functional.silu(torch.nn.functional.linear(hidden_states, gate_weight)) *
                torch.nn.functional.linear(hidden_states, up_weight)).matmul(down_weight.t())

    def fused_lora(
        self,
        x: torch.Tensor,
        base_weight: torch.Tensor,
        lora_A: torch.Tensor,
        lora_B: torch.Tensor,
        scaling: float
    ) -> torch.Tensor:
        """
        Run fused LoRA forward.
        """
        if self.available and self.kernels:
            return FusedLoRAFunction.apply(x, base_weight, lora_A, lora_B, scaling, self.kernels)
        return None
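Because Accelerator is a singleton whose fused_rmsnorm and fused_lora return None when the optional langtrain_cuda extension is absent, callers need a plain PyTorch branch of their own; fused_attention always works because both paths dispatch to torch.nn.functional.scaled_dot_product_attention. A minimal calling sketch, assuming no custom kernels are installed (tensor shapes are illustrative only):

import torch
from langtune.acceleration import Accelerator

accel = Accelerator()                      # singleton: every call returns the same instance

q = k = v = torch.randn(1, 8, 16, 64)      # (batch, heads, seq_len, head_dim)
out = accel.fused_attention(q, k, v)       # falls back to PyTorch SDPA, so this always succeeds

x = torch.randn(16, 512)
w = torch.ones(512)
y = accel.fused_rmsnorm(x, w)              # None without langtrain_cuda
if y is None:                              # standard RMSNorm fallback in plain PyTorch
    y = w * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)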