liger-kernel 0.5.3__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
1
1
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
2
2
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
3
+ from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
3
4
  from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
4
5
  from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
5
6
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
@@ -0,0 +1,213 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class LigerFusedLinearRLHFBase(torch.autograd.Function):
8
+ @staticmethod
9
+ def forward(
10
+ ctx,
11
+ _input,
12
+ weight,
13
+ attention_mask,
14
+ rewards,
15
+ bias=None,
16
+ loss_fn=None,
17
+ num_generations=4,
18
+ beta=0.1,
19
+ compiled=True,
20
+ use_ref_model=False,
21
+ ref_input=None,
22
+ ref_weight=None,
23
+ ref_bias=None,
24
+ ):
25
+ """Chunked forward pass for RLHF loss computation."""
26
+ # Save for backward
27
+ ctx.beta = beta
28
+ ctx.rewards = rewards
29
+
30
+ # Initialize accumulators
31
+ loss_acc = torch.zeros((), device=_input.device)
32
+ grad_weight = torch.zeros_like(weight) # [V, H]
33
+ grad_inputs = []
34
+ grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
35
+ aggregated_metrics = []
36
+
37
+ # Create a partial function with fixed arguments
38
+ compute_loss = partial(
39
+ LigerFusedLinearRLHFBase._compute_chunk_loss,
40
+ beta=beta,
41
+ use_ref_model=use_ref_model,
42
+ ref_weight=ref_weight,
43
+ ref_bias=ref_bias,
44
+ rlhf_loss_fn=loss_fn,
45
+ )
46
+
47
+ def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
48
+ """Fused forward and backward for a chunk."""
49
+ if bias is not None:
50
+ return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
51
+ input_chunk, # arg 0
52
+ weight, # arg 1
53
+ attention_mask_chunk, # arg 2
54
+ rewards_chunk, # arg 3
55
+ ref_input_chunk, # arg 4
56
+ bias, # arg 5
57
+ )
58
+ else:
59
+ return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
60
+ input_chunk, # arg 0
61
+ weight, # arg 1
62
+ attention_mask_chunk, # arg 2
63
+ rewards_chunk, # arg 3
64
+ ref_input_chunk, # arg 4
65
+ )
66
+
67
+ def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
68
+ if bias is not None:
69
+ (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
70
+ input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
71
+ )
72
+ grad_bias.add_(chunk_grad_bias)
73
+ else:
74
+ (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
75
+ input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
76
+ )
77
+
78
+ # Accumulate gradients and loss
79
+ grad_weight.add_(chunk_grad_weight)
80
+ grad_inputs.append(chunk_grad_input)
81
+ loss_acc.add_(chunk_loss)
82
+
83
+ # Initialize storage for metrics on first chunk
84
+ if len(aggregated_metrics) == 0:
85
+ for metric in chunk_metrics:
86
+ if metric.ndim == 0:
87
+ aggregated_metrics.append(torch.zeros((), device=metric.device))
88
+ else:
89
+ aggregated_metrics.append([])
90
+
91
+ # Accumulate metrics
92
+ for i, metric in enumerate(chunk_metrics):
93
+ if metric.ndim == 0:
94
+ aggregated_metrics[i].add_(metric)
95
+ else:
96
+ aggregated_metrics[i].append(metric)
97
+
98
+ if compiled:
99
+ accumulate_chunk = torch.compile(accumulate_chunk)
100
+
101
+ # Process input in chunks
102
+ chunks = max(1, _input.shape[0] // num_generations)
103
+ _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
104
+ _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
105
+ _rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
106
+ _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks
107
+
108
+ for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
109
+ _input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
110
+ ):
111
+ # Mark dynamic dimensions
112
+ torch._dynamo.mark_dynamic(input_chunk, 1)
113
+ torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
114
+ if ref_input_chunk is not None:
115
+ torch._dynamo.mark_dynamic(ref_input_chunk, 1)
116
+
117
+ accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)
118
+
119
+ # Scale accumulated loss by number of chunks since we're averaging
120
+ loss_acc = loss_acc / chunks
121
+
122
+ # Combine gradients
123
+ grad_input = torch.cat(grad_inputs, dim=0)
124
+
125
+ # Save for backward
126
+ ctx.save_for_backward(grad_input, grad_weight, grad_bias)
127
+
128
+ # Finalize metrics
129
+ final_metrics = []
130
+ for metric in aggregated_metrics:
131
+ if isinstance(metric, list):
132
+ final_metrics.append(torch.cat(metric, dim=0))
133
+ else:
134
+ final_metrics.append(metric / chunks)
135
+
136
+ return loss_acc, tuple(final_metrics)
137
+
138
+ @staticmethod
139
+ def _compute_chunk_loss(
140
+ input_chunk,
141
+ weight,
142
+ attention_mask_chunk,
143
+ rewards_chunk,
144
+ ref_input_chunk=None,
145
+ bias=None,
146
+ beta=0.1,
147
+ use_ref_model=False,
148
+ ref_weight=None,
149
+ ref_bias=None,
150
+ rlhf_loss_fn=None,
151
+ ):
152
+ """Compute loss for a single chunk."""
153
+ # Get policy log probabilities using chunk_forward
154
+ log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)
155
+
156
+ # Get reference log probabilities if needed
157
+ ref_log_probs = None
158
+ if use_ref_model and ref_input_chunk is not None:
159
+ with torch.no_grad():
160
+ ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)
161
+
162
+ # Compute chunk loss and metrics using the provided loss function
163
+ chunk_loss, chunk_metrics = rlhf_loss_fn(
164
+ log_probs=log_probs,
165
+ attention_mask=attention_mask_chunk,
166
+ rewards=rewards_chunk,
167
+ ref_log_probs=ref_log_probs,
168
+ beta=beta,
169
+ )
170
+
171
+ return chunk_loss, (logits_mean, *chunk_metrics)
172
+
173
+ @staticmethod
174
+ def chunk_forward(input_chunk, weight, bias=None):
175
+ """Forward pass computation for a single chunk without explicit reshaping."""
176
+ # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
177
+ logits = torch.matmul(input_chunk, weight.t())
178
+ if bias is not None:
179
+ logits = logits + bias # Broadcasts bias to [B, T, V]
180
+
181
+ # Compute log probabilities using softmax over the last dimension
182
+ log_probs = F.log_softmax(logits.float(), dim=-1)
183
+
184
+ # Monitoring: compute mean of logits
185
+ batch_size, seq_len, _ = input_chunk.shape
186
+ logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
187
+ return log_probs, logits, logits_mean
188
+
189
+ @staticmethod
190
+ def backward(ctx, grad_output, *grad_metrics):
191
+ """Backward pass for RLHF loss."""
192
+ grad_input, grad_weight, grad_bias = ctx.saved_tensors
193
+ if grad_output != 1.0:
194
+ grad_input = grad_input * grad_output
195
+ grad_weight = grad_weight * grad_output
196
+ if grad_bias is not None:
197
+ grad_bias = grad_bias * grad_output
198
+
199
+ return (
200
+ grad_input,
201
+ grad_weight,
202
+ None, # grad_attention_mask
203
+ None, # grad_rewards
204
+ grad_bias,
205
+ None, # grad_loss_fn
206
+ None, # grad_chunk_size
207
+ None, # grad_beta
208
+ None, # grad_compiled
209
+ None, # grad_use_ref_model
210
+ None, # grad_ref_input
211
+ None, # grad_ref_weight
212
+ None, # grad_ref_bias
213
+ )
@@ -0,0 +1,160 @@
1
+ import torch
2
+
3
+ from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase
4
+
5
+
6
class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
    """Fused linear projection + GRPO loss.

    Chunking, per-chunk gradient computation and the backward replay come from
    ``LigerFusedLinearRLHFBase``. This class supplies the GRPO per-chunk loss
    (``rlhf_loss_fn``) and maps its own argument order onto the base forward.
    """

    @staticmethod
    def rlhf_loss_fn(
        log_probs,
        attention_mask,
        rewards,
        ref_log_probs=None,
        beta=0.1,
        **kwargs,
    ):
        """GRPO Loss Function, intended to match the GRPOTrainer implementation.

        Args:
            log_probs: Policy log-probabilities with the vocabulary on the last
                dim (``(batch_size, seq_len, vocab)`` per the comments below).
            attention_mask: ``(batch_size, seq_len)`` mask applied to per-token losses.
            rewards: ``(batch_size,)`` rewards of this chunk. The chunk is
                normalised as one group.
            ref_log_probs: Optional reference log-probs with the same layout as
                ``log_probs``. When None, the KL term is identically zero.
            beta: Weight of the KL penalty.
            **kwargs: Ignored.

        Returns:
            ``(loss, metrics)``:

            - ``loss`` is the masked per-token loss summed and divided by the
              number of unmasked tokens.
            - ``metrics`` is (mean chosen log-prob, std chosen log-prob,
              mean of all log-probs, mean per-sequence KL).
        """
        # NOTE(review): tokens are selected by argmax over the policy's own
        # distribution, not taken from the completions that were actually
        # sampled. Fixing this needs the sampled token ids as an extra input.
        chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
            -1
        )  # (batch_size, seq_len)

        # Get reference model probabilities
        if ref_log_probs is not None:
            with torch.no_grad():
                ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
        else:
            ref_token_logprobs = chosen_token_logprobs.detach()

        # Group statistics over this chunk's rewards (both are scalars)
        mean_grouped_rewards = rewards.mean()
        if rewards.numel() > 1:
            std_grouped_rewards = rewards.std()
        else:
            # The unbiased std of one sample is NaN. A lone completion has no
            # group to compare against, so its advantage must come out as 0,
            # not NaN.
            std_grouped_rewards = torch.zeros_like(mean_grouped_rewards)

        # Calculate advantages using the same epsilon as in GRPOTrainer
        eps = 1e-4
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)

        # The ratio evaluates to 1 but carries d(logp) in its gradient
        ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
        policy_loss = -ratio * advantages.unsqueeze(1)

        # Per-token KL estimator: exp(ref - logp) - (ref - logp) - 1
        kl_div = (
            torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
        )

        # Combine losses
        per_token_loss = policy_loss + beta * kl_div

        # Apply masking and normalize by the number of unmasked tokens
        masked_loss = per_token_loss * attention_mask
        seq_lengths = attention_mask.sum()
        seq_lengths = torch.clamp(seq_lengths, min=1.0)
        loss = masked_loss.sum() / seq_lengths

        # Clamp per-row token counts so fully-masked rows add 0 to the KL metric, not NaN
        row_token_counts = attention_mask.sum(dim=1).clamp(min=1)

        # Calculate metrics
        metrics = (
            chosen_token_logprobs.mean(),  # mean log prob
            chosen_token_logprobs.std(),  # std log prob
            log_probs.mean(),  # mean all log probs
            ((kl_div * attention_mask).sum(dim=1) / row_token_counts).mean(),  # mean KL div
        )

        return loss, metrics

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        attention_mask,
        rewards,
        bias=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        beta=0.1,
        compiled=True,
        use_ref_model=True,
        num_generations=1,
    ):
        """Run the base chunked forward with the GRPO loss; arguments are forwarded by name."""
        return LigerFusedLinearRLHFBase.forward(
            ctx=ctx,
            _input=_input,
            weight=weight,
            attention_mask=attention_mask,
            loss_fn=LigerFusedLinearGRPOFunction.rlhf_loss_fn,
            rewards=rewards,
            bias=bias,
            ref_input=ref_input,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            beta=beta,
            compiled=compiled,
            use_ref_model=use_ref_model,
            num_generations=num_generations,
        )

    @staticmethod
    def backward(ctx, grad_output, *grad_metrics):
        """Backward pass for GRPO loss.

        Args:
            grad_output: Gradient of the loss (scalar)
            grad_metrics: Gradients of the metrics (not used in backward computation)

        Returns:
            One gradient per argument of this class's ``forward``, in its
            argument order. The first five come from the base class.
        """
        grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output)
        return (
            *grads[:5],  # grad_input, grad_weight, grad_attention_mask, grad_rewards, grad_bias
            None,  # grad_ref_input
            None,  # grad_ref_weight
            None,  # grad_ref_bias
            None,  # grad_beta
            None,  # grad_compiled
            None,  # grad_use_ref_model
            None,  # grad_num_generations
        )
118
+
119
+
120
class LigerFusedLinearGRPOLoss(torch.nn.Module):
    """Fused linear layer with GRPO loss.

    This is a thin ``nn.Module`` wrapper around
    ``LigerFusedLinearGRPOFunction.apply``. The configuration stored in
    ``__init__`` is forwarded on every call.

    Args:
        beta: Weight of the KL penalty.
        compiled: Whether the per-chunk step is wrapped in ``torch.compile``.
        use_ref_model: Use reference-model log-probs for the KL term. This only
            takes effect when ``ref_input`` is passed to ``forward``.
        num_generations: Forwarded to the function as the chunking/group-size parameter.
    """

    def __init__(
        self,
        beta: float = 0.1,
        compiled: bool = True,
        use_ref_model: bool = True,
        num_generations: int = 1,
    ):
        super().__init__()
        self.beta = beta
        self.compiled = compiled
        self.use_ref_model = use_ref_model
        self.num_generations = num_generations

    def forward(
        self,
        _input,
        lin_weight,
        attention_mask,
        rewards,
        bias=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
    ):
        """Compute the GRPO loss and metrics for ``_input`` projected by ``lin_weight``.

        Returns:
            Whatever ``LigerFusedLinearGRPOFunction.apply`` returns (loss, metrics).
        """
        # The module defaults to use_ref_model=True while ref_input defaults to
        # None. Without reference activations the reference branch must be
        # disabled rather than handing None to the base class's chunking of
        # ref_input.
        use_ref_model = self.use_ref_model and ref_input is not None
        return LigerFusedLinearGRPOFunction.apply(
            _input,
            lin_weight,
            attention_mask,
            rewards,
            bias,
            ref_input,
            ref_weight,
            ref_bias,
            self.beta,
            self.compiled,
            use_ref_model,
            self.num_generations,
        )
@@ -43,20 +43,20 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
43
43
  3. Maintain reasonable distance from the reference model
44
44
 
45
45
  Args:
46
- chosen_logps: Log probabilities of chosen tokens (batch_size,)
47
- rejected_logps: Log probabilities of rejected tokens (batch_size,)
46
+ average_log_prob_chunk: Log probabilities for the chunk (batch_size,)
47
+ preference_labels_chunk: Preference labels for the chunk (batch_size,)
48
48
  full_target: Non chunked full target tensor
49
- ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
50
- ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
51
- beta: Weight for the direct preference loss
49
+ ref_average_log_prob_chunk: Reference log probs for the chunk (batch_size,)
50
+ beta: Weight for the KTO loss
52
51
  kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
53
52
  Returns:
54
- Tuple of (loss, chosen_rewards, rejected_rewards):
55
53
  - loss: The KTO loss value
56
- - chosen_rewards: Reward signals for chosen responses (detached)
57
- - rejected_rewards: Reward signals for rejected responses (detached)
58
54
  """
59
- logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
55
+ if ref_average_log_prob_chunk is not None:
56
+ logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
57
+ else:
58
+ logratios_chunk = average_log_prob_chunk
59
+
60
60
  multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
61
61
  if kl is not None:
62
62
  losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
@@ -289,9 +289,9 @@ def cross_entropy_forward(
289
289
  weight_sum = 0.0
290
290
  if weight is not None:
291
291
  assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
292
- assert torch.is_floating_point(
293
- weight
294
- ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
292
+ assert torch.is_floating_point(weight), (
293
+ f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
294
+ )
295
295
  sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
296
296
  weight_sum = weight.sum().item()
297
297
  # ensure weight is contiguous
@@ -58,9 +58,9 @@ def fused_linear_cross_entropy_forward(
58
58
  ce_weight_sum = 0.0
59
59
  if ce_weight is not None:
60
60
  assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
61
- assert torch.is_floating_point(
62
- ce_weight
63
- ), f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
61
+ assert torch.is_floating_point(ce_weight), (
62
+ f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
63
+ )
64
64
  total_sum_non_ignore_ce_weight = (
65
65
  torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
66
66
  )
@@ -195,9 +195,9 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
195
195
  """
196
196
  has_label = False
197
197
  if shift_labels is not None:
198
- assert shift_labels.shape == (
199
- teacher_input.shape[0],
200
- ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
198
+ assert shift_labels.shape == (teacher_input.shape[0],), (
199
+ f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
200
+ )
201
201
  shift_labels = shift_labels.contiguous()
202
202
  has_label = True
203
203
 
liger_kernel/ops/jsd.py CHANGED
@@ -157,9 +157,9 @@ class LigerJSDFunction(torch.autograd.Function):
157
157
  """
158
158
  has_label = False
159
159
  if shift_labels is not None:
160
- assert shift_labels.shape == (
161
- _input.shape[0],
162
- ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
160
+ assert shift_labels.shape == (_input.shape[0],), (
161
+ f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
162
+ )
163
163
  shift_labels = shift_labels.contiguous()
164
164
  has_label = True
165
165
 
@@ -57,13 +57,14 @@ def _layer_norm_forward_kernel(
57
57
  B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
58
58
 
59
59
  mean = tl.sum(X_row, axis=0) / n_cols
60
- var = tl.sum((X_row - mean) * (X_row - mean), axis=0) / n_cols
60
+ Xmm = tl.where(mask, X_row - mean, 0)
61
+ var = tl.sum(Xmm * Xmm, axis=0) / n_cols
61
62
  rstd = rsqrt(var + eps)
62
63
 
63
64
  tl.store(Mean_ptr, mean)
64
65
  tl.store(RSTD_ptr, rstd)
65
66
 
66
- Y_row = (X_row - mean) * rstd * W_row + B_row
67
+ Y_row = Xmm * rstd * W_row + B_row
67
68
 
68
69
  tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
69
70
 
@@ -147,9 +148,11 @@ def layer_norm_forward(X, W, B, eps):
147
148
  Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
148
149
  Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
149
150
  RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
150
- assert (
151
- X.shape[1] == W.shape[0]
152
- ), f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}"
151
+ if X.shape[1] != W.shape[0]:
152
+ raise ValueError(
153
+ f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
154
+ f"must match weight size (W.shape[0]={W.shape[0]})"
155
+ )
153
156
 
154
157
  _layer_norm_forward_kernel[(n_rows,)](
155
158
  Y,
@@ -190,11 +193,21 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
190
193
 
191
194
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
192
195
  if n_cols > BLOCK_SIZE:
193
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
196
+ raise RuntimeError(
197
+ f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
198
+ )
194
199
 
195
200
  rows_per_program = math.ceil(n_rows / sm_count)
196
201
  grid = (sm_count,)
197
- triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
202
+ triton_dtype = (
203
+ tl.float32
204
+ if X.dtype == torch.float32
205
+ else tl.bfloat16
206
+ if X.dtype == torch.bfloat16
207
+ else tl.float16
208
+ if X.dtype == torch.float16
209
+ else tl.float32 # fallback to float32 for other types
210
+ )
198
211
  _layer_norm_backward_kernel[grid](
199
212
  X,
200
213
  W,