langtune-0.1.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,870 @@
1
+ """
2
+ optimizations.py: Efficient fine-tuning optimizations for Langtune
3
+
4
+ This module implements memory-efficient and speed-optimized techniques
5
+ inspired by Unsloth, including:
6
+ - 4-bit quantization (QLoRA style)
7
+ - Rotary Position Embeddings (RoPE)
8
+ - Fused cross-entropy loss
9
+ - Memory-efficient attention
10
+ - Gradient checkpointing utilities
11
+ """
12
+
13
+ import math
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.cuda.amp import autocast, GradScaler
18
+ from typing import Optional, Tuple, Dict, Any, Union
19
+ import logging
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Check for available optimizations
24
+ FLASH_ATTENTION_AVAILABLE = False
25
+ try:
26
+ from flash_attn import flash_attn_func
27
+ FLASH_ATTENTION_AVAILABLE = True
28
+ except ImportError:
29
+ pass
30
+
31
+ # Check for bitsandbytes (optional 4-bit support)
32
+ BITSANDBYTES_AVAILABLE = False
33
+ try:
34
+ import bitsandbytes as bnb
35
+ BITSANDBYTES_AVAILABLE = True
36
+ except ImportError:
37
+ pass
38
+
39
+
40
+ # =============================================================================
41
+ # Quantization Utilities
42
+ # =============================================================================
43
+
44
+ def quantize_tensor_4bit(
45
+ tensor: torch.Tensor,
46
+ group_size: int = 64
47
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
48
+ """
49
+ Quantize a tensor to 4-bit representation.
50
+
51
+ Args:
52
+ tensor: Input tensor to quantize
53
+ group_size: Number of elements per quantization group
54
+
55
+ Returns:
56
+ Tuple of (quantized_data, scales, zero_points)
57
+ """
58
+ original_shape = tensor.shape
59
+ tensor = tensor.reshape(-1, group_size)
60
+
61
+ # Compute min/max per group
62
+ min_vals = tensor.min(dim=1, keepdim=True)[0]
63
+ max_vals = tensor.max(dim=1, keepdim=True)[0]
64
+
65
+ # Compute scale and zero point
66
+ scales = (max_vals - min_vals) / 15.0 # 4-bit = 16 levels
67
+ zero_points = min_vals
68
+
69
+ # Quantize
70
+ quantized = torch.clamp(
71
+ torch.round((tensor - zero_points) / (scales + 1e-8)),
72
+ 0, 15
73
+ ).to(torch.uint8)
74
+
75
+ # Pack two 4-bit values into one uint8
76
+ packed = quantized[:, ::2] | (quantized[:, 1::2] << 4)
77
+
78
+ return packed, scales.squeeze(-1), zero_points.squeeze(-1)
79
+
80
+
81
+ def dequantize_tensor_4bit(
82
+ packed: torch.Tensor,
83
+ scales: torch.Tensor,
84
+ zero_points: torch.Tensor,
85
+ group_size: int = 64,
86
+ output_shape: Optional[Tuple[int, ...]] = None
87
+ ) -> torch.Tensor:
88
+ """
89
+ Dequantize a 4-bit tensor back to float.
90
+
91
+ Args:
92
+ packed: Packed 4-bit tensor
93
+ scales: Quantization scales
94
+ zero_points: Quantization zero points
95
+ group_size: Number of elements per group
96
+ output_shape: Original tensor shape
97
+
98
+ Returns:
99
+ Dequantized float tensor
100
+ """
101
+ # Unpack 4-bit values
102
+ low = packed & 0x0F
103
+ high = (packed >> 4) & 0x0F
104
+
105
+ # Interleave to get original order
106
+ batch_size = packed.shape[0]
107
+ unpacked = torch.zeros(batch_size, group_size, device=packed.device, dtype=torch.float32)
108
+ unpacked[:, ::2] = low.float()
109
+ unpacked[:, 1::2] = high.float()
110
+
111
+ # Dequantize
112
+ dequantized = unpacked * scales.unsqueeze(-1) + zero_points.unsqueeze(-1)
113
+
114
+ if output_shape is not None:
115
+ dequantized = dequantized.reshape(output_shape)
116
+
117
+ return dequantized
118
+
119
+
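# A minimal usage sketch, assuming the two helpers above are in scope; the
# tensor size and group size are arbitrary. Round-tripping a random block
# through 4-bit quantization should give a small but non-zero error.
import torch

w = torch.randn(256, 256)
packed, scales, zeros = quantize_tensor_4bit(w, group_size=64)
w_hat = dequantize_tensor_4bit(packed, scales, zeros, group_size=64, output_shape=w.shape)
print((w - w_hat).abs().mean())  # a fraction of one quantization step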
120
+ class QuantizedLinear(nn.Module):
121
+ """
122
+ 4-bit quantized linear layer with efficient on-the-fly dequantization.
123
+
124
+ This provides roughly 75% weight-memory savings compared to FP16
125
+ (excluding the small per-group scale/zero-point overhead), with minimal accuracy loss.
126
+ """
127
+
128
+ def __init__(
129
+ self,
130
+ in_features: int,
131
+ out_features: int,
132
+ bias: bool = True,
133
+ group_size: int = 64,
134
+ compute_dtype: torch.dtype = torch.float16
135
+ ):
136
+ super().__init__()
137
+ self.in_features = in_features
138
+ self.out_features = out_features
139
+ self.group_size = group_size
140
+ self.compute_dtype = compute_dtype
141
+
142
+ # Ensure dimensions are compatible with group size
143
+ assert in_features % group_size == 0, f"in_features ({in_features}) must be divisible by group_size ({group_size})"
144
+
145
+ # Number of groups
146
+ self.num_groups = in_features // group_size
147
+
148
+ # Quantized weight storage (packed 4-bit)
149
+ self.register_buffer(
150
+ 'weight_packed',
151
+ torch.zeros(out_features * self.num_groups, group_size // 2, dtype=torch.uint8)
152
+ )
153
+ self.register_buffer(
154
+ 'weight_scales',
155
+ torch.ones(out_features * self.num_groups, dtype=compute_dtype)
156
+ )
157
+ self.register_buffer(
158
+ 'weight_zeros',
159
+ torch.zeros(out_features * self.num_groups, dtype=compute_dtype)
160
+ )
161
+
162
+ if bias:
163
+ self.bias = nn.Parameter(torch.zeros(out_features, dtype=compute_dtype))
164
+ else:
165
+ self.register_parameter('bias', None)
166
+
167
+ self._is_initialized = False
168
+
169
+ def initialize_from_weight(self, weight: torch.Tensor):
170
+ """Initialize quantized weights from a float weight tensor."""
171
+ assert weight.shape == (self.out_features, self.in_features)
172
+
173
+ # Reshape for group-wise quantization
174
+ weight_grouped = weight.reshape(self.out_features * self.num_groups, self.group_size)
175
+
176
+ # Quantize
177
+ packed, scales, zeros = quantize_tensor_4bit(weight_grouped, self.group_size)
178
+
179
+ self.weight_packed.copy_(packed)
180
+ self.weight_scales.copy_(scales.to(self.compute_dtype))
181
+ self.weight_zeros.copy_(zeros.to(self.compute_dtype))
182
+ self._is_initialized = True
183
+
184
+ def get_weight(self) -> torch.Tensor:
185
+ """Dequantize and return the full weight matrix."""
186
+ weight = dequantize_tensor_4bit(
187
+ self.weight_packed,
188
+ self.weight_scales,
189
+ self.weight_zeros,
190
+ self.group_size,
191
+ (self.out_features, self.in_features)
192
+ )
193
+ return weight.to(self.compute_dtype)
194
+
195
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
196
+ """Forward pass with on-the-fly dequantization."""
197
+ weight = self.get_weight()
198
+ output = F.linear(x.to(self.compute_dtype), weight, self.bias)
199
+ return output
200
+
201
+
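# A minimal usage sketch, assuming QuantizedLinear above is in scope; the layer
# sizes are arbitrary and compute_dtype is kept at float32 so the example also
# runs on CPU (float16 is the usual choice on GPU).
import torch
import torch.nn as nn

dense = nn.Linear(512, 256)
qlinear = QuantizedLinear(512, 256, bias=True, group_size=64, compute_dtype=torch.float32)
qlinear.initialize_from_weight(dense.weight.data)
qlinear.bias.data.copy_(dense.bias.data)

x = torch.randn(4, 512)
print((dense(x) - qlinear(x)).abs().max())  # small quantization error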
202
+ class LoRALinear4bit(nn.Module):
203
+ """
204
+ QLoRA-style 4-bit quantized linear with LoRA adapters.
205
+
206
+ Combines:
207
+ - 4-bit quantized base weights (frozen)
208
+ - Full-precision LoRA adapters (trainable)
209
+ """
210
+
211
+ def __init__(
212
+ self,
213
+ in_features: int,
214
+ out_features: int,
215
+ rank: int = 8,
216
+ alpha: float = 16.0,
217
+ dropout: float = 0.1,
218
+ group_size: int = 64,
219
+ compute_dtype: torch.dtype = torch.float16
220
+ ):
221
+ super().__init__()
222
+ self.in_features = in_features
223
+ self.out_features = out_features
224
+ self.rank = rank
225
+ self.alpha = alpha
226
+ self.scaling = alpha / rank
227
+
228
+ # Quantized base weight (frozen)
229
+ self.base = QuantizedLinear(
230
+ in_features, out_features,
231
+ bias=False,
232
+ group_size=group_size,
233
+ compute_dtype=compute_dtype
234
+ )
235
+
236
+ # LoRA adapters (trainable, full precision)
237
+ self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
238
+ self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
239
+ self.dropout = nn.Dropout(dropout)
240
+
241
+ # Initialize LoRA weights
242
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
243
+ nn.init.zeros_(self.lora_B)
244
+
245
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
246
+ """Forward pass: base output + LoRA adaptation."""
247
+ # Base forward (quantized)
248
+ base_output = self.base(x)
249
+
250
+ # LoRA forward (full precision)
251
+ lora_output = self.dropout(x.to(self.lora_A.dtype)) @ self.lora_A.T @ self.lora_B.T  # cast to adapter dtype to avoid fp16/fp32 matmul mismatch
252
+ lora_output = lora_output * self.scaling
253
+
254
+ return base_output + lora_output.to(base_output.dtype)
255
+
256
+
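# A minimal usage sketch, assuming LoRALinear4bit above is in scope; the
# 512x512 weight is a random stand-in for a pretrained matrix. Only the LoRA
# factors are trainable, and because lora_B starts at zero the layer initially
# reproduces the quantized base output.
import torch

layer = LoRALinear4bit(512, 512, rank=8, alpha=16.0, dropout=0.0, compute_dtype=torch.float32)
layer.base.initialize_from_weight(torch.randn(512, 512))

trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)                            # 2 * 512 * 8 = 8192 adapter parameters
print(layer(torch.randn(2, 512)).shape)     # torch.Size([2, 512])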
257
+ # =============================================================================
258
+ # Rotary Position Embeddings (RoPE)
259
+ # =============================================================================
260
+
261
+ class RotaryPositionEmbedding(nn.Module):
262
+ """
263
+ Rotary Position Embeddings (RoPE) for efficient position encoding.
264
+
265
+ RoPE encodes position information directly into the attention computation,
266
+ providing better extrapolation and computational efficiency.
267
+ """
268
+
269
+ def __init__(
270
+ self,
271
+ dim: int,
272
+ max_seq_len: int = 2048,
273
+ base: int = 10000
274
+ ):
275
+ super().__init__()
276
+ self.dim = dim
277
+ self.max_seq_len = max_seq_len
278
+ self.base = base
279
+
280
+ # Precompute frequencies
281
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
282
+ self.register_buffer('inv_freq', inv_freq)
283
+
284
+ # Precompute cos/sin for max sequence length
285
+ self._precompute_cos_sin(max_seq_len)
286
+
287
+ def _precompute_cos_sin(self, seq_len: int):
288
+ """Precompute cos and sin for given sequence length."""
289
+ t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
290
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
291
+ emb = torch.cat((freqs, freqs), dim=-1)
292
+
293
+ self.register_buffer('cos_cached', emb.cos())
294
+ self.register_buffer('sin_cached', emb.sin())
295
+
296
+ def forward(
297
+ self,
298
+ q: torch.Tensor,
299
+ k: torch.Tensor,
300
+ seq_len: int
301
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
302
+ """
303
+ Apply rotary position embeddings to query and key tensors.
304
+
305
+ Args:
306
+ q: Query tensor of shape (batch, heads, seq_len, head_dim)
307
+ k: Key tensor of shape (batch, heads, seq_len, head_dim)
308
+ seq_len: Sequence length
309
+
310
+ Returns:
311
+ Tuple of (q_rotated, k_rotated)
312
+ """
313
+ # Extend cache if needed
314
+ if seq_len > self.max_seq_len:
315
+ self._precompute_cos_sin(seq_len)
316
+ self.max_seq_len = seq_len
317
+
318
+ cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0).to(q.dtype)
319
+ sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0).to(q.dtype)
320
+
321
+ q_rotated = self._apply_rotary(q, cos, sin)
322
+ k_rotated = self._apply_rotary(k, cos, sin)
323
+
324
+ return q_rotated, k_rotated
325
+
326
+ def _apply_rotary(
327
+ self,
328
+ x: torch.Tensor,
329
+ cos: torch.Tensor,
330
+ sin: torch.Tensor
331
+ ) -> torch.Tensor:
332
+ """Apply rotary embedding to a single tensor."""
333
+ # Split into two halves
334
+ x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
335
+
336
+ # Rotate
337
+ rotated = torch.cat((-x2, x1), dim=-1)
338
+
339
+ # Apply rotation
340
+ return x * cos + rotated * sin
341
+
342
+
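# A minimal usage sketch, assuming RotaryPositionEmbedding above is in scope;
# shapes are arbitrary. The rotation mixes pairs of channels by a position-
# dependent angle, so per-vector norms are preserved.
import torch

rope = RotaryPositionEmbedding(dim=64, max_seq_len=128)
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
q_rot, k_rot = rope(q, k, seq_len=128)

print(q_rot.shape)  # torch.Size([2, 8, 128, 64])
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))  # True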
343
+ def apply_rotary_pos_emb(
344
+ q: torch.Tensor,
345
+ k: torch.Tensor,
346
+ cos: torch.Tensor,
347
+ sin: torch.Tensor
348
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
349
+ """
350
+ Functional interface for applying rotary position embeddings.
351
+
352
+ More efficient for models that precompute cos/sin.
353
+ """
354
+ # Rotate query
355
+ q1, q2 = q[..., :q.shape[-1]//2], q[..., q.shape[-1]//2:]
356
+ q_rotated = torch.cat((-q2, q1), dim=-1)
357
+ q_out = q * cos + q_rotated * sin
358
+
359
+ # Rotate key
360
+ k1, k2 = k[..., :k.shape[-1]//2], k[..., k.shape[-1]//2:]
361
+ k_rotated = torch.cat((-k2, k1), dim=-1)
362
+ k_out = k * cos + k_rotated * sin
363
+
364
+ return q_out, k_out
365
+
366
+
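# A minimal sketch of the functional form, assuming apply_rotary_pos_emb above
# is in scope; here cos/sin are precomputed once, as a model would typically do
# at build time rather than inside every attention call.
import torch

head_dim, seq_len = 64, 32
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((angles, angles), dim=-1)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]   # (1, 1, seq_len, head_dim)

q = torch.randn(1, 4, seq_len, head_dim)
k = torch.randn(1, 4, seq_len, head_dim)
q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin)
print(q_out.shape, k_out.shape)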
367
+ # =============================================================================
368
+ # Memory-Efficient Attention
369
+ # =============================================================================
370
+
371
+ class MemoryEfficientAttention(nn.Module):
372
+ """
373
+ Memory-efficient attention implementation.
374
+
375
+ Uses Flash Attention when available, otherwise falls back to chunked
376
+ computation to reduce peak memory usage.
377
+ """
378
+
379
+ def __init__(
380
+ self,
381
+ embed_dim: int,
382
+ num_heads: int,
383
+ dropout: float = 0.0,
384
+ use_flash: bool = True,
385
+ chunk_size: int = 1024
386
+ ):
387
+ super().__init__()
388
+ self.embed_dim = embed_dim
389
+ self.num_heads = num_heads
390
+ self.head_dim = embed_dim // num_heads
391
+ self.scale = self.head_dim ** -0.5
392
+ self.dropout = dropout
393
+ self.chunk_size = chunk_size
394
+
395
+ # Use flash attention if available and requested
396
+ self.use_flash = use_flash and FLASH_ATTENTION_AVAILABLE
397
+
398
+ if self.use_flash:
399
+ logger.info("Using Flash Attention for memory-efficient computation")
400
+ else:
401
+ logger.info("Using chunked attention (Flash Attention not available)")
402
+
403
+ def forward(
404
+ self,
405
+ q: torch.Tensor,
406
+ k: torch.Tensor,
407
+ v: torch.Tensor,
408
+ mask: Optional[torch.Tensor] = None,
409
+ is_causal: bool = False
410
+ ) -> torch.Tensor:
411
+ """
412
+ Compute attention efficiently.
413
+
414
+ Args:
415
+ q: Query tensor (batch, heads, seq_len, head_dim)
416
+ k: Key tensor (batch, heads, seq_len, head_dim)
417
+ v: Value tensor (batch, heads, seq_len, head_dim)
418
+ mask: Optional attention mask
419
+ is_causal: Whether to use causal masking
420
+
421
+ Returns:
422
+ Attention output tensor
423
+ """
424
+ if self.use_flash:
425
+ return self._flash_attention(q, k, v, is_causal)
426
+ else:
427
+ return self._chunked_attention(q, k, v, mask, is_causal)
428
+
429
+ def _flash_attention(
430
+ self,
431
+ q: torch.Tensor,
432
+ k: torch.Tensor,
433
+ v: torch.Tensor,
434
+ is_causal: bool
435
+ ) -> torch.Tensor:
436
+ """Use flash attention for efficient computation."""
437
+ # Flash attention expects (batch, seq_len, heads, head_dim)
438
+ q = q.transpose(1, 2)
439
+ k = k.transpose(1, 2)
440
+ v = v.transpose(1, 2)
441
+
442
+ out = flash_attn_func(
443
+ q, k, v,
444
+ dropout_p=self.dropout if self.training else 0.0,
445
+ causal=is_causal
446
+ )
447
+
448
+ return out.transpose(1, 2)
449
+
450
+ def _chunked_attention(
451
+ self,
452
+ q: torch.Tensor,
453
+ k: torch.Tensor,
454
+ v: torch.Tensor,
455
+ mask: Optional[torch.Tensor],
456
+ is_causal: bool
457
+ ) -> torch.Tensor:
458
+ """
459
+ Compute attention in chunks to reduce memory usage.
460
+
461
+ This is a fallback when flash attention is not available.
462
+ """
463
+ batch, heads, seq_len, head_dim = q.shape
464
+
465
+ # For short sequences, use standard attention
466
+ if seq_len <= self.chunk_size:
467
+ return self._standard_attention(q, k, v, mask, is_causal)
468
+
469
+ # Chunked computation
470
+ output = torch.zeros_like(q)
471
+
472
+ for start in range(0, seq_len, self.chunk_size):
473
+ end = min(start + self.chunk_size, seq_len)
474
+ q_chunk = q[:, :, start:end, :]
475
+
476
+ # Compute attention scores for this chunk
477
+ attn = torch.matmul(q_chunk, k.transpose(-2, -1)) * self.scale
478
+
479
+ # Apply causal mask if needed
480
+ if is_causal:
481
+ causal_mask = torch.tril(
482
+ torch.ones(end - start, seq_len, device=q.device),
483
+ diagonal=start
484
+ )
485
+ attn = attn.masked_fill(causal_mask == 0, float('-inf'))
486
+
487
+ if mask is not None:
488
+ attn = attn + mask[:, :, start:end, :]
489
+
490
+ attn = F.softmax(attn, dim=-1)
491
+
492
+ if self.training and self.dropout > 0:
493
+ attn = F.dropout(attn, p=self.dropout)
494
+
495
+ output[:, :, start:end, :] = torch.matmul(attn, v)
496
+
497
+ return output
498
+
499
+ def _standard_attention(
500
+ self,
501
+ q: torch.Tensor,
502
+ k: torch.Tensor,
503
+ v: torch.Tensor,
504
+ mask: Optional[torch.Tensor],
505
+ is_causal: bool
506
+ ) -> torch.Tensor:
507
+ """Standard scaled dot-product attention."""
508
+ attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
509
+
510
+ if is_causal:
511
+ seq_len = q.shape[2]
512
+ causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=q.device))
513
+ attn = attn.masked_fill(causal_mask == 0, float('-inf'))
514
+
515
+ if mask is not None:
516
+ attn = attn + mask
517
+
518
+ attn = F.softmax(attn, dim=-1)
519
+
520
+ if self.training and self.dropout > 0:
521
+ attn = F.dropout(attn, p=self.dropout)
522
+
523
+ return torch.matmul(attn, v)
524
+
525
+
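# A minimal usage sketch, assuming MemoryEfficientAttention above is in scope;
# use_flash=False forces the chunked path so the example runs without the
# flash-attn package, and seq_len > chunk_size exercises the chunk loop.
import torch

attn = MemoryEfficientAttention(embed_dim=512, num_heads=8, use_flash=False, chunk_size=256)
q = torch.randn(2, 8, 1024, 64)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)
out = attn(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 1024, 64])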
526
+ # =============================================================================
527
+ # Fused Cross-Entropy Loss
528
+ # =============================================================================
529
+
530
+ def fused_cross_entropy(
531
+ logits: torch.Tensor,
532
+ labels: torch.Tensor,
533
+ ignore_index: int = -100,
534
+ reduction: str = 'mean',
535
+ label_smoothing: float = 0.0,
536
+ chunk_size: int = 4096
537
+ ) -> torch.Tensor:
538
+ """
539
+ Memory-efficient cross-entropy loss.
540
+
541
+ Computes cross-entropy in chunks to avoid materializing the full
542
+ softmax matrix, significantly reducing memory usage for large vocabularies.
543
+
544
+ Args:
545
+ logits: Model outputs of shape (batch * seq_len, vocab_size)
546
+ labels: Target labels of shape (batch * seq_len,)
547
+ ignore_index: Label index to ignore
548
+ reduction: 'none', 'mean', or 'sum'
549
+ label_smoothing: Label smoothing factor
550
+ chunk_size: Chunk size for processing
551
+
552
+ Returns:
553
+ Cross-entropy loss
554
+ """
555
+ batch_size = logits.shape[0]
556
+ vocab_size = logits.shape[1]
557
+
558
+ # For small batches, use standard cross-entropy
559
+ if batch_size <= chunk_size:
560
+ return F.cross_entropy(
561
+ logits, labels,
562
+ ignore_index=ignore_index,
563
+ reduction=reduction,
564
+ label_smoothing=label_smoothing
565
+ )
566
+
567
+ # Chunked computation for large batches
568
+ total_loss = 0.0
569
+ total_count = 0
570
+
571
+ for start in range(0, batch_size, chunk_size):
572
+ end = min(start + chunk_size, batch_size)
573
+ chunk_logits = logits[start:end]
574
+ chunk_labels = labels[start:end]
575
+
576
+ # Compute loss for this chunk
577
+ chunk_loss = F.cross_entropy(
578
+ chunk_logits, chunk_labels,
579
+ ignore_index=ignore_index,
580
+ reduction='sum',
581
+ label_smoothing=label_smoothing
582
+ )
583
+
584
+ # Count valid labels
585
+ valid_mask = chunk_labels != ignore_index
586
+ chunk_count = valid_mask.sum().item()
587
+
588
+ total_loss += chunk_loss
589
+ total_count += chunk_count
590
+
591
+ if reduction == 'none':
592
+ raise ValueError("reduction='none' not supported in chunked mode")
593
+ elif reduction == 'sum':
594
+ return total_loss
595
+ else: # mean
596
+ return total_loss / max(total_count, 1)
597
+
598
+
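# A minimal usage sketch, assuming fused_cross_entropy above is in scope; the
# shapes and the small chunk_size are arbitrary. For 'mean' reduction the
# chunked path agrees with F.cross_entropy up to floating-point accumulation.
import torch
import torch.nn.functional as F

logits = torch.randn(8, 128, 1000)               # (batch, seq_len, vocab)
labels = torch.randint(0, 1000, (8, 128))
labels[:, :4] = -100                             # e.g. masked prompt tokens

loss = fused_cross_entropy(logits.reshape(-1, 1000), labels.reshape(-1), chunk_size=256)
ref = F.cross_entropy(logits.reshape(-1, 1000), labels.reshape(-1), ignore_index=-100)
print(torch.allclose(loss, ref, atol=1e-4))      # True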
599
+ # =============================================================================
600
+ # Gradient Checkpointing Utilities
601
+ # =============================================================================
602
+
603
+ class GradientCheckpointFunction(torch.autograd.Function):
604
+ """
605
+ Custom gradient checkpointing function for transformer layers.
606
+
607
+ More efficient than torch.utils.checkpoint by avoiding
608
+ redundant context saves.
609
+ """
610
+
611
+ @staticmethod
612
+ def forward(ctx, run_function, preserve_rng_state, *args):
613
+ ctx.run_function = run_function
614
+ ctx.preserve_rng_state = preserve_rng_state
615
+
616
+ # Save RNG state if needed
617
+ if preserve_rng_state:
618
+ ctx.fwd_cpu_state = torch.get_rng_state()
619
+ ctx.had_cuda_in_fwd = False
620
+ if torch.cuda.is_available():
621
+ ctx.had_cuda_in_fwd = True
622
+ ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
623
+
624
+ # Save input tensors for backward
625
+ ctx.save_for_backward(*args)
626
+
627
+ with torch.no_grad():
628
+ outputs = run_function(*args)
629
+
630
+ return outputs
631
+
632
+ @staticmethod
633
+ def backward(ctx, *output_grads):
634
+ if not torch.autograd._is_checkpoint_valid():
635
+ raise RuntimeError(
636
+ "Checkpointing is not compatible with .grad()"
637
+ )
638
+
639
+ inputs = ctx.saved_tensors
640
+
641
+ # Restore RNG state so dropout and other stochastic ops replay identically
642
+ if ctx.preserve_rng_state:
643
+ torch.set_rng_state(ctx.fwd_cpu_state)
644
+ if ctx.had_cuda_in_fwd:
645
+ for device, state in zip(ctx.fwd_gpu_devices, ctx.fwd_gpu_states):
646
+ torch.cuda.set_rng_state(state, device)
647
+ with torch.enable_grad():
648
+ # Recompute forward pass under the restored RNG state
649
+ detached_inputs = [x.detach().requires_grad_(x.requires_grad) for x in inputs]
650
+ outputs = ctx.run_function(*detached_inputs)
651
+
652
+ # Backpropagate through the recomputed graph. torch.autograd.backward is
653
+ # used instead of torch.autograd.grad so that parameter .grad fields of the
654
+ # checkpointed module are populated, not just the input gradients.
655
+ if isinstance(outputs, torch.Tensor):
656
+ outputs = (outputs,)
657
+
658
+ outputs_with_grad = []
659
+ grads_with_grad = []
660
+ for out, grad in zip(outputs, output_grads):
661
+ if isinstance(out, torch.Tensor) and out.requires_grad:
662
+ outputs_with_grad.append(out)
663
+ grads_with_grad.append(grad)
664
+
665
+ torch.autograd.backward(outputs_with_grad, grads_with_grad)
666
+
667
+ # Collect gradients for the checkpointed inputs
668
+ input_grads = [
669
+ x.grad if x.requires_grad else None
670
+ for x in detached_inputs
671
+ ]
672
+ return (None, None) + tuple(input_grads)
673
+
674
+
675
+ def get_device_states(*args):
676
+ """Get CUDA RNG state for gradient checkpointing."""
677
+ fwd_gpu_devices = []
678
+ fwd_gpu_states = []
679
+
680
+ for arg in args:
681
+ if isinstance(arg, torch.Tensor) and arg.is_cuda:
682
+ device = arg.device
683
+ if device not in fwd_gpu_devices:
684
+ fwd_gpu_devices.append(device)
685
+ fwd_gpu_states.append(torch.cuda.get_rng_state(device))
686
+
687
+ return fwd_gpu_devices, fwd_gpu_states
688
+
689
+
690
+ def checkpoint(function, *args, preserve_rng_state: bool = True):
691
+ """
692
+ Apply gradient checkpointing to a function.
693
+
694
+ Args:
695
+ function: Function to checkpoint
696
+ *args: Arguments to the function
697
+ preserve_rng_state: Whether to preserve RNG state
698
+
699
+ Returns:
700
+ Function output with gradient checkpointing applied
701
+ """
702
+ return GradientCheckpointFunction.apply(function, preserve_rng_state, *args)
703
+
704
+
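# A minimal usage sketch, assuming checkpoint above is in scope; the block is
# an arbitrary feed-forward stack. Its intermediate activations are discarded
# after the forward pass and recomputed when backward() reaches this node.
import torch
import torch.nn as nn

block = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256))
x = torch.randn(4, 256, requires_grad=True)

out = checkpoint(block, x)         # forward runs under no_grad
out.sum().backward()               # block is re-run here to rebuild the graph
print(x.grad.shape)                # torch.Size([4, 256])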
705
+ # =============================================================================
706
+ # Mixed Precision Training Utilities
707
+ # =============================================================================
708
+
709
+ class MixedPrecisionTrainer:
710
+ """
711
+ Mixed precision training utilities.
712
+
713
+ Provides automatic mixed precision (AMP) with gradient scaling
714
+ for stable training.
715
+ """
716
+
717
+ def __init__(
718
+ self,
719
+ enabled: bool = True,
720
+ dtype: torch.dtype = torch.float16,
721
+ init_scale: float = 65536.0,
722
+ growth_factor: float = 2.0,
723
+ backoff_factor: float = 0.5,
724
+ growth_interval: int = 2000
725
+ ):
726
+ self.enabled = enabled and torch.cuda.is_available()
727
+ self.dtype = dtype
728
+
729
+ if self.enabled:
730
+ self.scaler = GradScaler(
731
+ init_scale=init_scale,
732
+ growth_factor=growth_factor,
733
+ backoff_factor=backoff_factor,
734
+ growth_interval=growth_interval
735
+ )
736
+ logger.info(f"Mixed precision training enabled with {dtype}")
737
+ else:
738
+ self.scaler = None
739
+
740
+ @property
741
+ def autocast_context(self):
742
+ """Get autocast context manager."""
743
+ return autocast(enabled=self.enabled, dtype=self.dtype)
744
+
745
+ def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
746
+ """Scale loss for gradient stability."""
747
+ if self.scaler is not None:
748
+ return self.scaler.scale(loss)
749
+ return loss
750
+
751
+ def unscale_gradients(self, optimizer):
752
+ """Unscale gradients before clipping."""
753
+ if self.scaler is not None:
754
+ self.scaler.unscale_(optimizer)
755
+
756
+ def step(self, optimizer):
757
+ """Take optimizer step with gradient scaling."""
758
+ if self.scaler is not None:
759
+ self.scaler.step(optimizer)
760
+ self.scaler.update()
761
+ else:
762
+ optimizer.step()
763
+
764
+ def get_scale(self) -> float:
765
+ """Get current gradient scale."""
766
+ if self.scaler is not None:
767
+ return self.scaler.get_scale()
768
+ return 1.0
769
+
770
+
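# A minimal sketch of one AMP training step, assuming MixedPrecisionTrainer and
# fused_cross_entropy above are in scope; the tiny model, optimizer, and random
# batch are arbitrary stand-ins. On a CPU-only machine the trainer disables
# itself and the step runs in plain float32.
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(128, 1000).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
amp = MixedPrecisionTrainer(enabled=True, dtype=torch.float16)

inputs = torch.randn(4, 128, device=device)
targets = torch.randint(0, 1000, (4,), device=device)

with amp.autocast_context:
    loss = fused_cross_entropy(model(inputs), targets)

amp.scale_loss(loss).backward()          # scaled backward pass
amp.unscale_gradients(optimizer)         # so clipping sees true gradient norms
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
amp.step(optimizer)                      # optimizer step + scaler update
optimizer.zero_grad(set_to_none=True)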
771
+ # =============================================================================
772
+ # Memory Monitoring
773
+ # =============================================================================
774
+
775
+ def get_memory_stats() -> Dict[str, float]:
776
+ """
777
+ Get GPU memory statistics.
778
+
779
+ Returns:
780
+ Dictionary with memory stats in GB
781
+ """
782
+ if not torch.cuda.is_available():
783
+ return {}
784
+
785
+ return {
786
+ 'allocated': torch.cuda.memory_allocated() / 1e9,
787
+ 'reserved': torch.cuda.memory_reserved() / 1e9,
788
+ 'max_allocated': torch.cuda.max_memory_allocated() / 1e9,
789
+ 'max_reserved': torch.cuda.max_memory_reserved() / 1e9
790
+ }
791
+
792
+
793
+ def reset_memory_stats():
794
+ """Reset GPU memory statistics."""
795
+ if torch.cuda.is_available():
796
+ torch.cuda.reset_peak_memory_stats()
797
+
798
+
799
+ def cleanup_memory():
800
+ """Free unused GPU memory."""
801
+ if torch.cuda.is_available():
802
+ torch.cuda.empty_cache()
803
+
804
+
805
+ def log_memory_usage(prefix: str = ""):
806
+ """Log current GPU memory usage."""
807
+ stats = get_memory_stats()
808
+ if stats:
809
+ logger.info(
810
+ f"{prefix}Memory: {stats['allocated']:.2f}GB allocated, "
811
+ f"{stats['max_allocated']:.2f}GB peak"
812
+ )
813
+
814
+
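# A minimal usage sketch, assuming the helpers above are in scope: reset the
# peak counters, run some work, then log and release cached memory. On a
# CPU-only machine these calls are silent no-ops.
reset_memory_stats()
# ... forward/backward of one training step would go here ...
log_memory_usage(prefix="step 0 | ")
cleanup_memory()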
815
+ # =============================================================================
816
+ # Optimization Config
817
+ # =============================================================================
818
+
819
+ class OptimizationConfig:
820
+ """Configuration for optimization settings."""
821
+
822
+ def __init__(
823
+ self,
824
+ use_4bit: bool = False,
825
+ use_8bit: bool = False,
826
+ use_flash_attention: bool = True,
827
+ use_gradient_checkpointing: bool = True,
828
+ use_fused_ops: bool = True,
829
+ use_rope: bool = True,
830
+ gradient_accumulation_steps: int = 4,
831
+ mixed_precision: str = "fp16", # fp16, bf16, or fp32
832
+ group_size: int = 64,
833
+ chunk_size: int = 1024
834
+ ):
835
+ self.use_4bit = use_4bit
836
+ self.use_8bit = use_8bit
837
+ self.use_flash_attention = use_flash_attention
838
+ self.use_gradient_checkpointing = use_gradient_checkpointing
839
+ self.use_fused_ops = use_fused_ops
840
+ self.use_rope = use_rope
841
+ self.gradient_accumulation_steps = gradient_accumulation_steps
842
+ self.mixed_precision = mixed_precision
843
+ self.group_size = group_size
844
+ self.chunk_size = chunk_size
845
+
846
+ # Validate
847
+ if use_4bit and use_8bit:
848
+ raise ValueError("Cannot use both 4-bit and 8-bit quantization")
849
+
850
+ if mixed_precision not in ["fp16", "bf16", "fp32"]:
851
+ raise ValueError(f"Invalid mixed_precision: {mixed_precision}")
852
+
853
+ @property
854
+ def compute_dtype(self) -> torch.dtype:
855
+ """Get compute dtype based on config."""
856
+ if self.mixed_precision == "bf16":
857
+ return torch.bfloat16
858
+ elif self.mixed_precision == "fp16":
859
+ return torch.float16
860
+ else:
861
+ return torch.float32
862
+
863
+ def __repr__(self):
864
+ return (
865
+ f"OptimizationConfig("
866
+ f"use_4bit={self.use_4bit}, "
867
+ f"use_flash_attention={self.use_flash_attention}, "
868
+ f"use_gradient_checkpointing={self.use_gradient_checkpointing}, "
869
+ f"mixed_precision={self.mixed_precision})"
870
+ )
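# A minimal usage sketch, assuming the classes above are in scope: a QLoRA-style
# configuration and the pieces it would typically drive. The layer sizes are
# arbitrary placeholders, not taken from langtune itself.
import torch

config = OptimizationConfig(
    use_4bit=True,
    use_flash_attention=True,
    use_gradient_checkpointing=True,
    mixed_precision="fp16",
    group_size=64,
)
print(config)

amp = MixedPrecisionTrainer(enabled=config.mixed_precision != "fp32", dtype=config.compute_dtype)
layer = LoRALinear4bit(1024, 1024, rank=8, group_size=config.group_size,
                       compute_dtype=config.compute_dtype)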