langtune 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langtune/__init__.py +315 -0
- langtune/acceleration.py +132 -0
- langtune/api.py +320 -0
- langtune/auth.py +434 -0
- langtune/callbacks.py +268 -0
- langtune/cli.py +687 -0
- langtune/client.py +721 -0
- langtune/config.py +356 -0
- langtune/data.py +526 -0
- langtune/distributed.py +154 -0
- langtune/facade.py +174 -0
- langtune/finetune.py +491 -0
- langtune/generation.py +95 -0
- langtune/logging_utils.py +182 -0
- langtune/metrics.py +345 -0
- langtune/model/__init__.py +20 -0
- langtune/model/hub.py +109 -0
- langtune/model/loader.py +84 -0
- langtune/model/safetensors.py +104 -0
- langtune/model/weights.py +100 -0
- langtune/models.py +19 -0
- langtune/nn/fast_transformer.py +399 -0
- langtune/nn/layers.py +178 -0
- langtune/nn/transformer.py +254 -0
- langtune/optimizations.py +870 -0
- langtune/py.typed +2 -0
- langtune/schedulers.py +234 -0
- langtune/tokenizers.py +275 -0
- langtune/trainer.py +889 -0
- langtune/training/neftune.py +80 -0
- langtune/utils.py +337 -0
- langtune-0.1.19.dist-info/METADATA +257 -0
- langtune-0.1.19.dist-info/RECORD +37 -0
- langtune-0.1.19.dist-info/WHEEL +5 -0
- langtune-0.1.19.dist-info/entry_points.txt +2 -0
- langtune-0.1.19.dist-info/licenses/LICENSE +21 -0
- langtune-0.1.19.dist-info/top_level.txt +1 -0
langtune/optimizations.py
@@ -0,0 +1,870 @@
"""
optimizations.py: Efficient fine-tuning optimizations for Langtune

This module implements memory-efficient and speed-optimized techniques
inspired by Unsloth, including:
- 4-bit quantization (QLoRA style)
- Rotary Position Embeddings (RoPE)
- Fused cross-entropy loss
- Memory-efficient attention
- Gradient checkpointing utilities
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from typing import Optional, Tuple, Dict, Any, Union
import logging

logger = logging.getLogger(__name__)

# Check for available optimizations
FLASH_ATTENTION_AVAILABLE = False
try:
    from flash_attn import flash_attn_func
    FLASH_ATTENTION_AVAILABLE = True
except ImportError:
    pass

# Check for bitsandbytes (optional 4-bit support)
BITSANDBYTES_AVAILABLE = False
try:
    import bitsandbytes as bnb
    BITSANDBYTES_AVAILABLE = True
except ImportError:
    pass

# =============================================================================
# Quantization Utilities
# =============================================================================

def quantize_tensor_4bit(
    tensor: torch.Tensor,
    group_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to 4-bit representation.

    Args:
        tensor: Input tensor to quantize
        group_size: Number of elements per quantization group

    Returns:
        Tuple of (quantized_data, scales, zero_points)
    """
    original_shape = tensor.shape
    tensor = tensor.reshape(-1, group_size)

    # Compute min/max per group
    min_vals = tensor.min(dim=1, keepdim=True)[0]
    max_vals = tensor.max(dim=1, keepdim=True)[0]

    # Compute scale and zero point
    scales = (max_vals - min_vals) / 15.0  # 4-bit = 16 levels
    zero_points = min_vals

    # Quantize
    quantized = torch.clamp(
        torch.round((tensor - zero_points) / (scales + 1e-8)),
        0, 15
    ).to(torch.uint8)

    # Pack two 4-bit values into one uint8
    packed = quantized[:, ::2] | (quantized[:, 1::2] << 4)

    return packed, scales.squeeze(-1), zero_points.squeeze(-1)


def dequantize_tensor_4bit(
    packed: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    group_size: int = 64,
    output_shape: Tuple[int, ...] = None
) -> torch.Tensor:
    """
    Dequantize a 4-bit tensor back to float.

    Args:
        packed: Packed 4-bit tensor
        scales: Quantization scales
        zero_points: Quantization zero points
        group_size: Number of elements per group
        output_shape: Original tensor shape

    Returns:
        Dequantized float tensor
    """
    # Unpack 4-bit values
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F

    # Interleave to get original order
    batch_size = packed.shape[0]
    unpacked = torch.zeros(batch_size, group_size, device=packed.device, dtype=torch.float32)
    unpacked[:, ::2] = low.float()
    unpacked[:, 1::2] = high.float()

    # Dequantize
    dequantized = unpacked * scales.unsqueeze(-1) + zero_points.unsqueeze(-1)

    if output_shape is not None:
        dequantized = dequantized.reshape(output_shape)

    return dequantized

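The two helpers above round-trip group-quantized weights. As an illustrative sketch (editor's example, not part of the packaged file), with an arbitrary shape and the default 64-element groups:

import torch
from langtune.optimizations import quantize_tensor_4bit, dequantize_tensor_4bit

# 256 groups of 64 values; packing stores two 4-bit codes per uint8.
w = torch.randn(256, 64)
packed, scales, zeros = quantize_tensor_4bit(w, group_size=64)
print(packed.shape, packed.dtype)   # torch.Size([256, 32]) torch.uint8

w_hat = dequantize_tensor_4bit(packed, scales, zeros, group_size=64,
                               output_shape=w.shape)
# Reconstruction error is bounded by roughly half a quantization step per group.
print((w - w_hat).abs().max())
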
class QuantizedLinear(nn.Module):
    """
    4-bit quantized linear layer with efficient on-the-fly dequantization.

    This provides significant memory savings (75%) compared to FP16,
    with minimal accuracy loss.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        group_size: int = 64,
        compute_dtype: torch.dtype = torch.float16
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size
        self.compute_dtype = compute_dtype

        # Ensure dimensions are compatible with group size
        assert in_features % group_size == 0, f"in_features ({in_features}) must be divisible by group_size ({group_size})"

        # Number of groups
        self.num_groups = in_features // group_size

        # Quantized weight storage (packed 4-bit)
        self.register_buffer(
            'weight_packed',
            torch.zeros(out_features * self.num_groups, group_size // 2, dtype=torch.uint8)
        )
        self.register_buffer(
            'weight_scales',
            torch.ones(out_features * self.num_groups, dtype=compute_dtype)
        )
        self.register_buffer(
            'weight_zeros',
            torch.zeros(out_features * self.num_groups, dtype=compute_dtype)
        )

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, dtype=compute_dtype))
        else:
            self.register_parameter('bias', None)

        self._is_initialized = False

    def initialize_from_weight(self, weight: torch.Tensor):
        """Initialize quantized weights from a float weight tensor."""
        assert weight.shape == (self.out_features, self.in_features)

        # Reshape for group-wise quantization
        weight_grouped = weight.reshape(self.out_features * self.num_groups, self.group_size)

        # Quantize
        packed, scales, zeros = quantize_tensor_4bit(weight_grouped, self.group_size)

        self.weight_packed.copy_(packed)
        self.weight_scales.copy_(scales.to(self.compute_dtype))
        self.weight_zeros.copy_(zeros.to(self.compute_dtype))
        self._is_initialized = True

    def get_weight(self) -> torch.Tensor:
        """Dequantize and return the full weight matrix."""
        weight = dequantize_tensor_4bit(
            self.weight_packed,
            self.weight_scales,
            self.weight_zeros,
            self.group_size,
            (self.out_features, self.in_features)
        )
        return weight.to(self.compute_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with on-the-fly dequantization."""
        weight = self.get_weight()
        output = F.linear(x.to(self.compute_dtype), weight, self.bias)
        return output

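A minimal sketch (editor's example, not from the package) of converting an existing nn.Linear into a QuantizedLinear. The layer sizes are illustrative, and compute_dtype is set to float32 so the snippet also runs on CPU; outputs are approximate because of the 4-bit rounding:

import torch
import torch.nn as nn
from langtune.optimizations import QuantizedLinear

fp_layer = nn.Linear(4096, 1024, bias=False)
q_layer = QuantizedLinear(4096, 1024, bias=False, group_size=64,
                          compute_dtype=torch.float32)
q_layer.initialize_from_weight(fp_layer.weight.data)

x = torch.randn(2, 4096)
print(fp_layer(x).shape, q_layer(x).shape)   # both torch.Size([2, 1024])
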
class LoRALinear4bit(nn.Module):
    """
    QLoRA-style 4-bit quantized linear with LoRA adapters.

    Combines:
    - 4-bit quantized base weights (frozen)
    - Full-precision LoRA adapters (trainable)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.1,
        group_size: int = 64,
        compute_dtype: torch.dtype = torch.float16
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        # Quantized base weight (frozen)
        self.base = QuantizedLinear(
            in_features, out_features,
            bias=False,
            group_size=group_size,
            compute_dtype=compute_dtype
        )

        # LoRA adapters (trainable, full precision)
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.dropout = nn.Dropout(dropout)

        # Initialize LoRA weights
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass: base output + LoRA adaptation."""
        # Base forward (quantized)
        base_output = self.base(x)

        # LoRA forward (full precision)
        lora_output = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
        lora_output = lora_output * self.scaling

        return base_output + lora_output.to(base_output.dtype)

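A hypothetical usage sketch showing that only the LoRA adapters are trainable: the 4-bit base weights are stored as buffers, so they contribute no parameters. Rank, alpha, and shapes are illustrative:

import torch
from langtune.optimizations import LoRALinear4bit

layer = LoRALinear4bit(in_features=1024, out_features=1024, rank=8, alpha=16.0,
                       compute_dtype=torch.float32)

# Only lora_A (8x1024) and lora_B (1024x8) are nn.Parameters.
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)                      # 16384

y = layer(torch.randn(2, 16, 1024))
print(y.shape)                        # torch.Size([2, 16, 1024])
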
# =============================================================================
# Rotary Position Embeddings (RoPE)
# =============================================================================

class RotaryPositionEmbedding(nn.Module):
    """
    Rotary Position Embeddings (RoPE) for efficient position encoding.

    RoPE encodes position information directly into the attention computation,
    providing better extrapolation and computational efficiency.
    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 2048,
        base: int = 10000
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute cos/sin for max sequence length
        self._precompute_cos_sin(max_seq_len)

    def _precompute_cos_sin(self, seq_len: int):
        """Precompute cos and sin for given sequence length."""
        t = torch.arange(seq_len, device=self.inv_freq.device)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_len: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary position embeddings to query and key tensors.

        Args:
            q: Query tensor of shape (batch, heads, seq_len, head_dim)
            k: Key tensor of shape (batch, heads, seq_len, head_dim)
            seq_len: Sequence length

        Returns:
            Tuple of (q_rotated, k_rotated)
        """
        # Extend cache if needed
        if seq_len > self.max_seq_len:
            self._precompute_cos_sin(seq_len)
            self.max_seq_len = seq_len

        cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)

        q_rotated = self._apply_rotary(q, cos, sin)
        k_rotated = self._apply_rotary(k, cos, sin)

        return q_rotated, k_rotated

    def _apply_rotary(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor
    ) -> torch.Tensor:
        """Apply rotary embedding to a single tensor."""
        # Split into two halves
        x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]

        # Rotate
        rotated = torch.cat((-x2, x1), dim=-1)

        # Apply rotation
        return x * cos + rotated * sin


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Functional interface for applying rotary position embeddings.

    More efficient for models that precompute cos/sin.
    """
    # Rotate query
    q1, q2 = q[..., :q.shape[-1]//2], q[..., q.shape[-1]//2:]
    q_rotated = torch.cat((-q2, q1), dim=-1)
    q_out = q * cos + q_rotated * sin

    # Rotate key
    k1, k2 = k[..., :k.shape[-1]//2], k[..., k.shape[-1]//2:]
    k_rotated = torch.cat((-k2, k1), dim=-1)
    k_out = k * cos + k_rotated * sin

    return q_out, k_out

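An illustrative sketch (not from the package) of applying the RoPE module to query/key tensors in the documented (batch, heads, seq_len, head_dim) layout; batch size, head count, and sequence length are arbitrary:

import torch
from langtune.optimizations import RotaryPositionEmbedding

head_dim, seq_len = 64, 128
rope = RotaryPositionEmbedding(dim=head_dim, max_seq_len=2048)

q = torch.randn(1, 8, seq_len, head_dim)
k = torch.randn(1, 8, seq_len, head_dim)
q_rot, k_rot = rope(q, k, seq_len=seq_len)

# Shapes are unchanged, and rotation preserves per-position norms.
print(q_rot.shape, k_rot.shape)
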
# =============================================================================
# Memory-Efficient Attention
# =============================================================================

class MemoryEfficientAttention(nn.Module):
    """
    Memory-efficient attention implementation.

    Uses chunked computation to reduce peak memory usage,
    with automatic fallback to flash attention when available.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        use_flash: bool = True,
        chunk_size: int = 1024
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.dropout = dropout
        self.chunk_size = chunk_size

        # Use flash attention if available and requested
        self.use_flash = use_flash and FLASH_ATTENTION_AVAILABLE

        if self.use_flash:
            logger.info("Using Flash Attention for memory-efficient computation")
        else:
            logger.info("Using chunked attention (Flash Attention not available)")

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        is_causal: bool = False
    ) -> torch.Tensor:
        """
        Compute attention efficiently.

        Args:
            q: Query tensor (batch, heads, seq_len, head_dim)
            k: Key tensor (batch, heads, seq_len, head_dim)
            v: Value tensor (batch, heads, seq_len, head_dim)
            mask: Optional attention mask
            is_causal: Whether to use causal masking

        Returns:
            Attention output tensor
        """
        if self.use_flash:
            return self._flash_attention(q, k, v, is_causal)
        else:
            return self._chunked_attention(q, k, v, mask, is_causal)

    def _flash_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        is_causal: bool
    ) -> torch.Tensor:
        """Use flash attention for efficient computation."""
        # Flash attention expects (batch, seq_len, heads, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        out = flash_attn_func(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=is_causal
        )

        return out.transpose(1, 2)

    def _chunked_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor],
        is_causal: bool
    ) -> torch.Tensor:
        """
        Compute attention in chunks to reduce memory usage.

        This is a fallback when flash attention is not available.
        """
        batch, heads, seq_len, head_dim = q.shape

        # For short sequences, use standard attention
        if seq_len <= self.chunk_size:
            return self._standard_attention(q, k, v, mask, is_causal)

        # Chunked computation
        output = torch.zeros_like(q)

        for start in range(0, seq_len, self.chunk_size):
            end = min(start + self.chunk_size, seq_len)
            q_chunk = q[:, :, start:end, :]

            # Compute attention scores for this chunk
            attn = torch.matmul(q_chunk, k.transpose(-2, -1)) * self.scale

            # Apply causal mask if needed
            if is_causal:
                causal_mask = torch.tril(
                    torch.ones(end - start, seq_len, device=q.device),
                    diagonal=start
                )
                attn = attn.masked_fill(causal_mask == 0, float('-inf'))

            if mask is not None:
                attn = attn + mask[:, :, start:end, :]

            attn = F.softmax(attn, dim=-1)

            if self.training and self.dropout > 0:
                attn = F.dropout(attn, p=self.dropout)

            output[:, :, start:end, :] = torch.matmul(attn, v)

        return output

    def _standard_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor],
        is_causal: bool
    ) -> torch.Tensor:
        """Standard scaled dot-product attention."""
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if is_causal:
            seq_len = q.shape[2]
            causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=q.device))
            attn = attn.masked_fill(causal_mask == 0, float('-inf'))

        if mask is not None:
            attn = attn + mask

        attn = F.softmax(attn, dim=-1)

        if self.training and self.dropout > 0:
            attn = F.dropout(attn, p=self.dropout)

        return torch.matmul(attn, v)

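A sketch (editor's example) of exercising the chunked fallback; passing use_flash=False forces it even when flash-attn is installed. The shapes and chunk_size are illustrative, with the sequence length chosen larger than the chunk size so the loop actually splits the queries:

import torch
from langtune.optimizations import MemoryEfficientAttention

attn = MemoryEfficientAttention(embed_dim=512, num_heads=8,
                                use_flash=False, chunk_size=256)

# (batch, heads, seq_len, head_dim); 1024 > chunk_size triggers chunking.
q = torch.randn(2, 8, 1024, 64)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)

out = attn(q, k, v, is_causal=True)
print(out.shape)   # torch.Size([2, 8, 1024, 64])
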
# =============================================================================
# Fused Cross-Entropy Loss
# =============================================================================

def fused_cross_entropy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
    reduction: str = 'mean',
    label_smoothing: float = 0.0,
    chunk_size: int = 4096
) -> torch.Tensor:
    """
    Memory-efficient cross-entropy loss.

    Computes cross-entropy in chunks to avoid materializing the full
    softmax matrix, significantly reducing memory usage for large vocabularies.

    Args:
        logits: Model outputs of shape (batch * seq_len, vocab_size)
        labels: Target labels of shape (batch * seq_len,)
        ignore_index: Label index to ignore
        reduction: 'none', 'mean', or 'sum'
        label_smoothing: Label smoothing factor
        chunk_size: Chunk size for processing

    Returns:
        Cross-entropy loss
    """
    batch_size = logits.shape[0]
    vocab_size = logits.shape[1]

    # For small batches, use standard cross-entropy
    if batch_size <= chunk_size:
        return F.cross_entropy(
            logits, labels,
            ignore_index=ignore_index,
            reduction=reduction,
            label_smoothing=label_smoothing
        )

    # Chunked computation for large batches
    total_loss = 0.0
    total_count = 0

    for start in range(0, batch_size, chunk_size):
        end = min(start + chunk_size, batch_size)
        chunk_logits = logits[start:end]
        chunk_labels = labels[start:end]

        # Compute loss for this chunk
        chunk_loss = F.cross_entropy(
            chunk_logits, chunk_labels,
            ignore_index=ignore_index,
            reduction='sum',
            label_smoothing=label_smoothing
        )

        # Count valid labels
        valid_mask = chunk_labels != ignore_index
        chunk_count = valid_mask.sum().item()

        total_loss += chunk_loss
        total_count += chunk_count

    if reduction == 'none':
        raise ValueError("reduction='none' not supported in chunked mode")
    elif reduction == 'sum':
        return total_loss
    else:  # mean
        return total_loss / max(total_count, 1)

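An illustrative call, assuming the flattened (batch * seq_len, vocab_size) layout documented above; the sizes are arbitrary and chosen so the chunked path runs (editor's example, not from the package):

import torch
from langtune.optimizations import fused_cross_entropy

vocab_size = 32000
logits = torch.randn(8192, vocab_size)              # (batch * seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (8192,))
labels[::7] = -100                                   # masked positions are ignored

loss = fused_cross_entropy(logits, labels, chunk_size=2048)   # processed in 4 chunks
print(loss.item())
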
# =============================================================================
# Gradient Checkpointing Utilities
# =============================================================================

class GradientCheckpointFunction(torch.autograd.Function):
    """
    Custom gradient checkpointing function for transformer layers.

    More efficient than torch.utils.checkpoint by avoiding
    redundant context saves.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state

        # Save RNG state if needed
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            ctx.had_cuda_in_fwd = False
            if torch.cuda.is_available():
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save input tensors for backward
        ctx.save_for_backward(*args)

        with torch.no_grad():
            outputs = run_function(*args)

        return outputs

    @staticmethod
    def backward(ctx, *output_grads):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad()"
            )

        inputs = ctx.saved_tensors

        # Restore RNG state
        if ctx.preserve_rng_state:
            rng_devices = []
            if ctx.had_cuda_in_fwd:
                rng_devices = ctx.fwd_gpu_devices

        with torch.enable_grad():
            # Recompute forward pass
            detached_inputs = [x.detach().requires_grad_(x.requires_grad) for x in inputs]
            outputs = ctx.run_function(*detached_inputs)

        # Compute gradients
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        grads = torch.autograd.grad(
            outputs,
            [x for x in detached_inputs if x.requires_grad],
            output_grads,
            allow_unused=True
        )

        # Match gradients to inputs
        grad_iter = iter(grads)
        input_grads = []
        for x in detached_inputs:
            if x.requires_grad:
                input_grads.append(next(grad_iter))
            else:
                input_grads.append(None)

        return (None, None) + tuple(input_grads)


def get_device_states(*args):
    """Get CUDA RNG state for gradient checkpointing."""
    fwd_gpu_devices = []
    fwd_gpu_states = []

    for arg in args:
        if isinstance(arg, torch.Tensor) and arg.is_cuda:
            device = arg.device
            if device not in fwd_gpu_devices:
                fwd_gpu_devices.append(device)
                fwd_gpu_states.append(torch.cuda.get_rng_state(device))

    return fwd_gpu_devices, fwd_gpu_states


def checkpoint(function, *args, preserve_rng_state: bool = True):
    """
    Apply gradient checkpointing to a function.

    Args:
        function: Function to checkpoint
        *args: Arguments to the function
        preserve_rng_state: Whether to preserve RNG state

    Returns:
        Function output with gradient checkpointing applied
    """
    return GradientCheckpointFunction.apply(function, preserve_rng_state, *args)

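A minimal sketch of wrapping a small MLP with the checkpoint helper; the module and shapes are hypothetical, and only the gradient of the checkpointed input is inspected here (the example is the editor's, not from the package):

import torch
import torch.nn as nn
from langtune.optimizations import checkpoint

mlp = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256))

x = torch.randn(4, 256, requires_grad=True)
y = checkpoint(mlp, x)        # activations inside `mlp` are recomputed during backward
y.sum().backward()
print(x.grad.shape)           # torch.Size([4, 256])
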
# =============================================================================
# Mixed Precision Training Utilities
# =============================================================================

class MixedPrecisionTrainer:
    """
    Mixed precision training utilities.

    Provides automatic mixed precision (AMP) with gradient scaling
    for stable training.
    """

    def __init__(
        self,
        enabled: bool = True,
        dtype: torch.dtype = torch.float16,
        init_scale: float = 65536.0,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000
    ):
        self.enabled = enabled and torch.cuda.is_available()
        self.dtype = dtype

        if self.enabled:
            self.scaler = GradScaler(
                init_scale=init_scale,
                growth_factor=growth_factor,
                backoff_factor=backoff_factor,
                growth_interval=growth_interval
            )
            logger.info(f"Mixed precision training enabled with {dtype}")
        else:
            self.scaler = None

    @property
    def autocast_context(self):
        """Get autocast context manager."""
        return autocast(enabled=self.enabled, dtype=self.dtype)

    def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Scale loss for gradient stability."""
        if self.scaler is not None:
            return self.scaler.scale(loss)
        return loss

    def unscale_gradients(self, optimizer):
        """Unscale gradients before clipping."""
        if self.scaler is not None:
            self.scaler.unscale_(optimizer)

    def step(self, optimizer):
        """Take optimizer step with gradient scaling."""
        if self.scaler is not None:
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            optimizer.step()

    def get_scale(self) -> float:
        """Get current gradient scale."""
        if self.scaler is not None:
            return self.scaler.get_scale()
        return 1.0

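A sketch of a single training step built around MixedPrecisionTrainer; the model, data, learning rate, and clipping threshold are illustrative, and the class itself falls back to unscaled FP32 when CUDA is unavailable:

import torch
import torch.nn as nn
from langtune.optimizations import MixedPrecisionTrainer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(128, 10).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
mp = MixedPrecisionTrainer(enabled=True)

x = torch.randn(32, 128, device=device)
with mp.autocast_context:
    loss = model(x).pow(2).mean()

mp.scale_loss(loss).backward()
mp.unscale_gradients(optimizer)      # so clipping sees true gradient magnitudes
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
mp.step(optimizer)
optimizer.zero_grad()
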
# =============================================================================
# Memory Monitoring
# =============================================================================

def get_memory_stats() -> Dict[str, float]:
    """
    Get GPU memory statistics.

    Returns:
        Dictionary with memory stats in GB
    """
    if not torch.cuda.is_available():
        return {}

    return {
        'allocated': torch.cuda.memory_allocated() / 1e9,
        'reserved': torch.cuda.memory_reserved() / 1e9,
        'max_allocated': torch.cuda.max_memory_allocated() / 1e9,
        'max_reserved': torch.cuda.max_memory_reserved() / 1e9
    }


def reset_memory_stats():
    """Reset GPU memory statistics."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()


def cleanup_memory():
    """Free unused GPU memory."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def log_memory_usage(prefix: str = ""):
    """Log current GPU memory usage."""
    stats = get_memory_stats()
    if stats:
        logger.info(
            f"{prefix}Memory: {stats['allocated']:.2f}GB allocated, "
            f"{stats['max_allocated']:.2f}GB peak"
        )

# =============================================================================
# Optimization Config
# =============================================================================

class OptimizationConfig:
    """Configuration for optimization settings."""

    def __init__(
        self,
        use_4bit: bool = False,
        use_8bit: bool = False,
        use_flash_attention: bool = True,
        use_gradient_checkpointing: bool = True,
        use_fused_ops: bool = True,
        use_rope: bool = True,
        gradient_accumulation_steps: int = 4,
        mixed_precision: str = "fp16",  # fp16, bf16, or fp32
        group_size: int = 64,
        chunk_size: int = 1024
    ):
        self.use_4bit = use_4bit
        self.use_8bit = use_8bit
        self.use_flash_attention = use_flash_attention
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_fused_ops = use_fused_ops
        self.use_rope = use_rope
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.mixed_precision = mixed_precision
        self.group_size = group_size
        self.chunk_size = chunk_size

        # Validate
        if use_4bit and use_8bit:
            raise ValueError("Cannot use both 4-bit and 8-bit quantization")

        if mixed_precision not in ["fp16", "bf16", "fp32"]:
            raise ValueError(f"Invalid mixed_precision: {mixed_precision}")

    @property
    def compute_dtype(self) -> torch.dtype:
        """Get compute dtype based on config."""
        if self.mixed_precision == "bf16":
            return torch.bfloat16
        elif self.mixed_precision == "fp16":
            return torch.float16
        else:
            return torch.float32

    def __repr__(self):
        return (
            f"OptimizationConfig("
            f"use_4bit={self.use_4bit}, "
            f"use_flash_attention={self.use_flash_attention}, "
            f"use_gradient_checkpointing={self.use_gradient_checkpointing}, "
            f"mixed_precision={self.mixed_precision})"
        )