liger-kernel-nightly 0.5.3.dev20250218225514__py3-none-any.whl → 0.5.3.dev20250219232423__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

liger_kernel/chunked_loss/__init__.py
@@ -1,5 +1,6 @@
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
  from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
  from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
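
With this re-export, the new GRPO loss is importable directly from liger_kernel.chunked_loss alongside the other fused losses. A one-line sketch (constructor arguments are shown in the grpo_loss.py hunk further below):

    from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

    grpo_loss = LigerFusedLinearGRPOLoss()  # defaults per grpo_loss.py: beta=0.1, compiled=True, use_ref_model=True, num_generations=1
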
liger_kernel/chunked_loss/fused_linear_rlhf.py (new file)
@@ -0,0 +1,213 @@
+ from functools import partial
+
+ import torch
+ import torch.nn.functional as F
+
+
+ class LigerFusedLinearRLHFBase(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx,
+         _input,
+         weight,
+         attention_mask,
+         rewards,
+         bias=None,
+         loss_fn=None,
+         num_generations=4,
+         beta=0.1,
+         compiled=True,
+         use_ref_model=False,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+     ):
+         """Chunked forward pass for RLHF loss computation."""
+         # Save for backward
+         ctx.beta = beta
+         ctx.rewards = rewards
+
+         # Initialize accumulators
+         loss_acc = torch.zeros((), device=_input.device)
+         grad_weight = torch.zeros_like(weight)  # [V, H]
+         grad_inputs = []
+         grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
+         aggregated_metrics = []
+
+         # Create a partial function with fixed arguments
+         compute_loss = partial(
+             LigerFusedLinearRLHFBase._compute_chunk_loss,
+             beta=beta,
+             use_ref_model=use_ref_model,
+             ref_weight=ref_weight,
+             ref_bias=ref_bias,
+             rlhf_loss_fn=loss_fn,
+         )
+
+         def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
+             """Fused forward and backward for a chunk."""
+             if bias is not None:
+                 return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
+                     input_chunk,  # arg 0
+                     weight,  # arg 1
+                     attention_mask_chunk,  # arg 2
+                     rewards_chunk,  # arg 3
+                     ref_input_chunk,  # arg 4
+                     bias,  # arg 5
+                 )
+             else:
+                 return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
+                     input_chunk,  # arg 0
+                     weight,  # arg 1
+                     attention_mask_chunk,  # arg 2
+                     rewards_chunk,  # arg 3
+                     ref_input_chunk,  # arg 4
+                 )
+
+         def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
+             if bias is not None:
+                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
+                     input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
+                 )
+                 grad_bias.add_(chunk_grad_bias)
+             else:
+                 (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
+                     input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
+                 )
+
+             # Accumulate gradients and loss
+             grad_weight.add_(chunk_grad_weight)
+             grad_inputs.append(chunk_grad_input)
+             loss_acc.add_(chunk_loss)
+
+             # Initialize storage for metrics on first chunk
+             if len(aggregated_metrics) == 0:
+                 for metric in chunk_metrics:
+                     if metric.ndim == 0:
+                         aggregated_metrics.append(torch.zeros((), device=metric.device))
+                     else:
+                         aggregated_metrics.append([])
+
+             # Accumulate metrics
+             for i, metric in enumerate(chunk_metrics):
+                 if metric.ndim == 0:
+                     aggregated_metrics[i].add_(metric)
+                 else:
+                     aggregated_metrics[i].append(metric)
+
+         if compiled:
+             accumulate_chunk = torch.compile(accumulate_chunk)
+
+         # Process input in chunks
+         chunks = max(1, _input.shape[0] // num_generations)
+         _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+         _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
+         _rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
+         _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks
+
+         for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
+             _input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
+         ):
+             # Mark dynamic dimensions
+             torch._dynamo.mark_dynamic(input_chunk, 1)
+             torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
+             if ref_input_chunk is not None:
+                 torch._dynamo.mark_dynamic(ref_input_chunk, 1)
+
+             accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)
+
+         # Scale accumulated loss by number of chunks since we're averaging
+         loss_acc = loss_acc / chunks
+
+         # Combine gradients
+         grad_input = torch.cat(grad_inputs, dim=0)
+
+         # Save for backward
+         ctx.save_for_backward(grad_input, grad_weight, grad_bias)
+
+         # Finalize metrics
+         final_metrics = []
+         for metric in aggregated_metrics:
+             if isinstance(metric, list):
+                 final_metrics.append(torch.cat(metric, dim=0))
+             else:
+                 final_metrics.append(metric / chunks)
+
+         return loss_acc, tuple(final_metrics)
+
+     @staticmethod
+     def _compute_chunk_loss(
+         input_chunk,
+         weight,
+         attention_mask_chunk,
+         rewards_chunk,
+         ref_input_chunk=None,
+         bias=None,
+         beta=0.1,
+         use_ref_model=False,
+         ref_weight=None,
+         ref_bias=None,
+         rlhf_loss_fn=None,
+     ):
+         """Compute loss for a single chunk."""
+         # Get policy log probabilities using chunk_forward
+         log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)
+
+         # Get reference log probabilities if needed
+         ref_log_probs = None
+         if use_ref_model and ref_input_chunk is not None:
+             with torch.no_grad():
+                 ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)
+
+         # Compute chunk loss and metrics using the provided loss function
+         chunk_loss, chunk_metrics = rlhf_loss_fn(
+             log_probs=log_probs,
+             attention_mask=attention_mask_chunk,
+             rewards=rewards_chunk,
+             ref_log_probs=ref_log_probs,
+             beta=beta,
+         )
+
+         return chunk_loss, (logits_mean, *chunk_metrics)
+
+     @staticmethod
+     def chunk_forward(input_chunk, weight, bias=None):
+         """Forward pass computation for a single chunk without explicit reshaping."""
+         # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
+         logits = torch.matmul(input_chunk, weight.t())
+         if bias is not None:
+             logits = logits + bias  # Broadcasts bias to [B, T, V]
+
+         # Compute log probabilities using softmax over the last dimension
+         log_probs = F.log_softmax(logits.float(), dim=-1)
+
+         # Monitoring: compute mean of logits
+         batch_size, seq_len, _ = input_chunk.shape
+         logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
+         return log_probs, logits, logits_mean
+
+     @staticmethod
+     def backward(ctx, grad_output, *grad_metrics):
+         """Backward pass for RLHF loss."""
+         grad_input, grad_weight, grad_bias = ctx.saved_tensors
+         if grad_output != 1.0:
+             grad_input = grad_input * grad_output
+             grad_weight = grad_weight * grad_output
+             if grad_bias is not None:
+                 grad_bias = grad_bias * grad_output
+
+         return (
+             grad_input,
+             grad_weight,
+             None,  # grad_attention_mask
+             None,  # grad_rewards
+             grad_bias,
+             None,  # grad_loss_fn
+             None,  # grad_chunk_size
+             None,  # grad_beta
+             None,  # grad_compiled
+             None,  # grad_use_ref_model
+             None,  # grad_ref_input
+             None,  # grad_ref_weight
+             None,  # grad_ref_bias
+         )
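
The base class above drives everything through per-chunk calls to torch.func.grad_and_value: the batch is split along dim 0 into groups that appear to correspond to num_generations completions each, each group's loss and gradients are computed and accumulated, and the accumulated loss is finally divided by the number of chunks. A small illustration of the chunking arithmetic only, using hypothetical sizes (8 completions, 4 generations per prompt, 16 tokens, hidden size 32) that are not taken from the package:

    import torch

    B, T, H, num_generations = 8, 16, 32, 4      # hypothetical sizes
    _input = torch.randn(B, T, H)                # last hidden states for 8 completions

    # Same expression as in LigerFusedLinearRLHFBase.forward
    chunks = max(1, _input.shape[0] // num_generations)       # 8 // 4 = 2
    input_chunks = torch.chunk(_input, chunks=chunks, dim=0)

    print(chunks)                                  # 2
    print([tuple(c.shape) for c in input_chunks])  # [(4, 16, 32), (4, 16, 32)]
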
liger_kernel/chunked_loss/grpo_loss.py (new file)
@@ -0,0 +1,160 @@
+ import torch
+
+ from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase
+
+
+ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
+     @staticmethod
+     def rlhf_loss_fn(
+         log_probs,
+         attention_mask,
+         rewards,
+         ref_log_probs=None,
+         beta=0.1,
+         **kwargs,
+     ):
+         """GRPO Loss Function matching GRPOTrainer implementation."""
+         # Get chosen token probabilities
+         chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
+         chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
+             -1
+         )  # (batch_size, seq_len)
+
+         # Get reference model probabilities
+         if ref_log_probs is not None:
+             with torch.no_grad():
+                 ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
+         else:
+             ref_token_logprobs = chosen_token_logprobs.detach()
+
+         # Compute advantages per batch entry in a grouped fashion
+         mean_grouped_rewards = rewards.mean()  # [batch_size,]
+         std_grouped_rewards = rewards.std()  # [batch_size,]
+
+         # Calculate advantages using the same epsilon as in GRPOTrainer
+         eps = 1e-4
+         advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)
+
+         # Compute policy gradient loss with importance sampling ratio
+         ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
+         policy_loss = -ratio * advantages.unsqueeze(1)
+
+         # Compute KL penalty
+         kl_div = (
+             torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
+         )
+
+         # Combine losses
+         per_token_loss = policy_loss + beta * kl_div
+
+         # Apply masking and normalize
+         masked_loss = per_token_loss * attention_mask
+         seq_lengths = attention_mask.sum()
+         seq_lengths = torch.clamp(seq_lengths, min=1.0)
+         loss = masked_loss.sum() / seq_lengths
+
+         # Calculate metrics
+         metrics = (
+             chosen_token_logprobs.mean(),  # mean log prob
+             chosen_token_logprobs.std(),  # std log prob
+             log_probs.mean(),  # mean all log probs
+             ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(),  # mean KL div
+         )
+
+         return loss, metrics
+
+     @staticmethod
+     def forward(
+         ctx,
+         _input,
+         weight,
+         attention_mask,
+         rewards,
+         bias=None,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+         beta=0.1,
+         compiled=True,
+         use_ref_model=True,
+         num_generations=1,
+     ):
+         return LigerFusedLinearRLHFBase.forward(
+             ctx=ctx,
+             _input=_input,
+             weight=weight,
+             attention_mask=attention_mask,
+             loss_fn=LigerFusedLinearGRPOFunction.rlhf_loss_fn,
+             rewards=rewards,
+             bias=bias,
+             ref_input=ref_input,
+             ref_weight=ref_weight,
+             ref_bias=ref_bias,
+             beta=beta,
+             compiled=compiled,
+             use_ref_model=use_ref_model,
+             num_generations=num_generations,
+         )
+
+     @staticmethod
+     def backward(ctx, grad_output, *grad_metrics):
+         """Backward pass for GRPO loss.
+
+         Args:
+             grad_output: Gradient of the loss (scalar)
+             grad_metrics: Gradients of the metrics (not used in backward computation)
+         """
+         grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output)
+         return (
+             *grads[:5],  # grad_input, grad_weight, grad_attention_mask, grad_rewards, grad_bias
+             None,  # grad_ref_input
+             None,  # grad_ref_weight
+             None,  # grad_ref_bias
+             None,  # grad_beta
+             None,  # grad_compiled
+             None,  # grad_use_ref_model
+             None,  # grad_num_generations
+         )
+
+
+ class LigerFusedLinearGRPOLoss(torch.nn.Module):
+     """Fused linear layer with GRPO loss."""
+
+     def __init__(
+         self,
+         beta: float = 0.1,
+         compiled: bool = True,
+         use_ref_model: bool = True,
+         num_generations: int = 1,
+     ):
+         super().__init__()
+         self.beta = beta
+         self.compiled = compiled
+         self.use_ref_model = use_ref_model
+         self.num_generations = num_generations
+
+     def forward(
+         self,
+         _input,
+         lin_weight,
+         attention_mask,
+         rewards,
+         bias=None,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+     ):
+         return LigerFusedLinearGRPOFunction.apply(
+             _input,
+             lin_weight,
+             attention_mask,
+             rewards,
+             bias,
+             ref_input,
+             ref_weight,
+             ref_bias,
+             self.beta,
+             self.compiled,
+             self.use_ref_model,
+             self.num_generations,
+         )
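
A minimal end-to-end sketch of the new module follows. The tensor shapes are inferred from chunk_forward and rlhf_loss_fn rather than documented in the diff: hidden states [B, T, H], lm-head weight [V, H], attention mask [B, T], and one scalar reward per completion [B]. compiled=False and use_ref_model=False are chosen only to keep the smoke test small; this is not an official example from the package.

    import torch

    from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

    # Hypothetical sizes for a quick CPU smoke test.
    B, T, H, V = 4, 8, 16, 32

    hidden_states = torch.randn(B, T, H, requires_grad=True)  # policy model's last hidden states
    lm_head_weight = torch.randn(V, H, requires_grad=True)    # fused linear (lm_head) weight [V, H]
    attention_mask = torch.ones(B, T)
    rewards = torch.randn(B)

    grpo_loss = LigerFusedLinearGRPOLoss(
        beta=0.1,
        compiled=False,       # skip torch.compile for the sketch
        use_ref_model=False,  # reference log-probs fall back to detached policy log-probs
        num_generations=2,
    )

    loss, metrics = grpo_loss(hidden_states, lm_head_weight, attention_mask, rewards)
    loss.backward()
    print(loss.item(), len(metrics))  # scalar loss plus a tuple of logging metrics

With use_ref_model=False the reference log-probabilities equal the detached policy log-probabilities, so the KL penalty term evaluates to zero; passing ref_input and ref_weight (and optionally ref_bias) with use_ref_model=True enables the reference-model path shown in the base class.
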
*.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.5.3.dev20250218225514
+ Version: 0.5.3.dev20250219232423
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation

*.dist-info/RECORD
@@ -2,13 +2,15 @@ liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
  liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
  liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
- liger_kernel/chunked_loss/__init__.py,sha256=48m-8IMOAReZbi0HK5aV-KGBp2IsZSwFvdnzTNrS4bk,516
+ liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
  liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz46Vu_cAY,3682
  liger_kernel/chunked_loss/dpo_loss.py,sha256=wgjnwzLfrMUwV5mXgrq6G1YfQKWnbiFJegaP48BGJHY,4509
  liger_kernel/chunked_loss/functional.py,sha256=THWWpCnRVhTVfnPnyvQjdBvo1JDtxhwLmtZE_yiBBqM,817
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=5V8rdva89WyHVbmJ8JOmC4DYNOR6ByXfx3qlUieOZkI,11002
  liger_kernel/chunked_loss/fused_linear_preference.py,sha256=idK9V9NivoVITqVpiG0fEGUHSvinYWkn9-EYXZjR-KQ,18356
+ liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=sAApL4GQ3YL2F-ymIAF61GCpFfBgFcWF5LB4Gzd7LgY,8044
  liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=ZqYlXXhIphkJPxOS7iI70avgrr6x0skEtgpckZTYau0,9819
+ liger_kernel/chunked_loss/grpo_loss.py,sha256=M5qlQR-v5Rh8N3P3dPGNhOKygDFJ4516_rJaVPzU_-c,4980
  liger_kernel/chunked_loss/jsd_loss.py,sha256=yRCQdvd3ruTWP4A_BfU8VcZ6LepSUfO0Ob7stGnueQY,6052
  liger_kernel/chunked_loss/kto_loss.py,sha256=eVNW6HVCAm32shpfhbRlk92Flnjd7G32v0gK9DUUSOQ,5655
  liger_kernel/chunked_loss/orpo_loss.py,sha256=yjcrrbVeemLYodoSKT-FMSnaPtyKAZ3aOrvPD6tTY6Y,3617
@@ -61,9 +63,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel_nightly-0.5.3.dev20250218225514.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel_nightly-0.5.3.dev20250218225514.dist-info/METADATA,sha256=vfQVWTh3i9uSIfPngwGYYRQRdv4VlbYPbr-vciAqdv4,21625
- liger_kernel_nightly-0.5.3.dev20250218225514.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel_nightly-0.5.3.dev20250218225514.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
- liger_kernel_nightly-0.5.3.dev20250218225514.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel_nightly-0.5.3.dev20250218225514.dist-info/RECORD,,
+ liger_kernel_nightly-0.5.3.dev20250219232423.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel_nightly-0.5.3.dev20250219232423.dist-info/METADATA,sha256=cwUbT2K8osL4Xudlk83LjnQ3vtiZ1Ip8qhvSrXYcaTo,21625
+ liger_kernel_nightly-0.5.3.dev20250219232423.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel_nightly-0.5.3.dev20250219232423.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+ liger_kernel_nightly-0.5.3.dev20250219232423.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel_nightly-0.5.3.dev20250219232423.dist-info/RECORD,,