liger-kernel 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cpo_loss.py +51 -11
  3. liger_kernel/chunked_loss/dpo_loss.py +30 -4
  4. liger_kernel/chunked_loss/fused_linear_distillation.py +3 -3
  5. liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
  6. liger_kernel/chunked_loss/fused_linear_rlhf.py +240 -0
  7. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
  8. liger_kernel/chunked_loss/grpo_loss.py +194 -0
  9. liger_kernel/chunked_loss/jsd_loss.py +31 -6
  10. liger_kernel/chunked_loss/kto_loss.py +53 -15
  11. liger_kernel/chunked_loss/orpo_loss.py +37 -5
  12. liger_kernel/chunked_loss/simpo_loss.py +47 -11
  13. liger_kernel/ops/cross_entropy.py +7 -3
  14. liger_kernel/ops/fused_linear_cross_entropy.py +3 -3
  15. liger_kernel/ops/fused_linear_jsd.py +3 -3
  16. liger_kernel/ops/jsd.py +3 -3
  17. liger_kernel/ops/layer_norm.py +20 -7
  18. liger_kernel/ops/tvd.py +207 -0
  19. liger_kernel/ops/utils.py +1 -2
  20. liger_kernel/transformers/__init__.py +4 -0
  21. liger_kernel/transformers/cross_entropy.py +3 -3
  22. liger_kernel/transformers/functional.py +17 -0
  23. liger_kernel/transformers/fused_linear_cross_entropy.py +3 -3
  24. liger_kernel/transformers/group_norm.py +6 -6
  25. liger_kernel/transformers/model/olmo2.py +124 -0
  26. liger_kernel/transformers/model/qwen2_5_vl.py +205 -0
  27. liger_kernel/transformers/monkey_patch.py +239 -27
  28. liger_kernel/transformers/tvd.py +13 -0
  29. liger_kernel/utils.py +48 -1
  30. {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/METADATA +19 -4
  31. {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/RECORD +35 -29
  32. {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/WHEEL +1 -1
  33. {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/LICENSE +0 -0
  34. {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/NOTICE +0 -0
  35. {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/__init__.py
@@ -1,5 +1,6 @@
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
  from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
  from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
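With this export, the new GRPO loss is importable from the chunked_loss namespace alongside the existing losses; the constructor argument below is illustrative, not taken from this diff:

    from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

    grpo_loss = LigerFusedLinearGRPOLoss(beta=0.1)  # beta (KL-penalty weight) shown for illustration only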
liger_kernel/chunked_loss/cpo_loss.py
@@ -39,8 +39,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):

  return loss, chosen_rewards, rejected_rewards

- @staticmethod
+ @classmethod
  def forward(
+ cls,
  ctx,
  _input,
  weight,
@@ -52,27 +53,48 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
  label_smoothing=0.0,
  compute_nll_loss=True,
  compiled=True,
+ average_log_prob=False,
+ chunk_size=1,
  ):
- return LigerFusedLinearPreferenceBase.forward(
- ctx,
- _input,
- weight,
- target,
- bias,
- loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
+ """
+ Fused linear layer with CPO loss.
+ Args:
+ _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+ weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+ target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+ bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+ ignore_index (int): Index to ignore in loss computation
+ beta (float): Weight for the odds ratio loss
+ alpha (float): Weight for the alpha parameter
+ label_smoothing (float): Label smoothing factor
+ compute_nll_loss (bool): Whether to compute the NLL loss
+ compiled (bool): Whether to use torch compile
+ average_log_prob (bool): Whether to average the log probability per non-masked token
+ chunk_size (int): Size of chunks for processing.
+ Returns:
+ torch.Tensor: Computed loss
+ """
+ return super().forward(
+ cls=cls,
+ ctx=ctx,
+ _input=_input,
+ weight=weight,
+ target=target,
+ bias=bias,
  ignore_index=ignore_index,
  alpha=alpha,
  beta=beta,
  label_smoothing=label_smoothing,
  compute_nll_loss=compute_nll_loss,
- average_log_prob=False,
+ average_log_prob=average_log_prob,
  compiled=compiled,
+ chunk_size=chunk_size,
  )

  @staticmethod
  def backward(ctx, *grad_output):
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
- return *grads, None, None, None, None, None, None
+ return *grads, None, None, None, None, None, None, None, None


  class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -88,11 +110,19 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
  label_smoothing: float = 0.0,
  compute_nll_loss: bool = True,
  compiled: bool = True,
+ average_log_prob: bool = False,
+ chunk_size: int = 1,
  ):
  """
  Args:
  ignore_index (int): Index to ignore in the loss.
  beta (float): Weight for the odds ratio loss.
+ alpha (float): Weight for the alpha parameter.
+ label_smoothing (float): Label smoothing factor.
+ compute_nll_loss (bool): Whether to compute the NLL loss.
+ compiled (bool): Whether to use the torch compiled kernel.
+ average_log_prob (bool): Whether to average the log probability per non-masked token.
+ chunk_size (int): Size of chunks for processing.
  """
  super().__init__()
  self.ignore_index = ignore_index
@@ -101,8 +131,16 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
  self.label_smoothing = label_smoothing
  self.compute_nll_loss = compute_nll_loss
  self.compiled = compiled
+ self.average_log_prob = average_log_prob
+ self.chunk_size = chunk_size

- def forward(self, lin_weight, _input, target, bias=None):
+ def forward(
+ self,
+ lin_weight,
+ _input,
+ target,
+ bias=None,
+ ):
  return LigerFusedLinearCPOFunction.apply(
  _input,
  lin_weight,
@@ -114,4 +152,6 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
  self.label_smoothing,
  self.compute_nll_loss,
  self.compiled,
+ self.average_log_prob,
+ self.chunk_size,
  )
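A minimal usage sketch of the updated CPO module, with average_log_prob and chunk_size now configurable. Tensor shapes are assumptions: the chunked preference losses consume unflattened hidden states of shape (batch, seq_len, hidden) with the chosen half of the batch first, and the return handling is kept defensive since the fused function may also emit auxiliary reward statistics alongside the loss:

    import torch

    from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss

    B, T, H, V = 2, 8, 16, 32  # illustrative sizes; batch = 1 chosen + 1 rejected sequence
    lin_weight = torch.randn(V, H, requires_grad=True)  # LM-head weight
    _input = torch.randn(B, T, H, requires_grad=True)   # last hidden states
    target = torch.randint(0, V, (B, T))                # token ids

    cpo = LigerFusedLinearCPOLoss(beta=0.1, average_log_prob=True, chunk_size=1)  # new knobs in 0.5.5
    out = cpo(lin_weight, _input, target)
    loss = out[0] if isinstance(out, tuple) else out     # loss comes first; aux stats, if any, follow
    loss.backward()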
liger_kernel/chunked_loss/dpo_loss.py
@@ -52,8 +52,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
  return loss, chosen_rewards, rejected_rewards

- @staticmethod
+ @classmethod
  def forward(
+ cls,
  ctx,
  _input,
  weight,
@@ -67,14 +68,34 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  compute_nll_loss=False,
  compiled=True,
  use_ref_model=True,
+ chunk_size=1,
  ):
- return LigerFusedLinearPreferenceBase.forward(
+ """
+ Fused linear layer with DPO loss.
+ Args:
+ _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+ weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+ target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+ bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+ ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+ ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+ ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+ ignore_index (int): Index to ignore in loss computation
+ beta (float): Weight for the odds ratio loss
+ compute_nll_loss (bool): Whether to compute the NLL loss
+ compiled (bool): Whether to use torch compile
+ use_ref_model (bool): Whether to use a reference model
+ chunk_size (int): Size of chunks for processing.
+ Returns:
+ torch.Tensor: Computed loss
+ """
+ return super().forward(
+ cls=cls,
  ctx=ctx,
  _input=_input,
  weight=weight,
  target=target,
  bias=bias,
- loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
  ignore_index=ignore_index,
  beta=beta,
  compute_nll_loss=compute_nll_loss,
@@ -83,12 +104,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  ref_input=ref_input,
  ref_weight=ref_weight,
  ref_bias=ref_bias,
+ chunk_size=chunk_size,
  )

  @staticmethod
  def backward(ctx, *grad_output):
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
- return *grads, None, None, None, None, None, None, None, None
+ return *grads, None, None, None, None, None, None, None, None, None


  class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -103,6 +125,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  compute_nll_loss: bool = False,
  compiled: bool = True,
  use_ref_model: bool = True,
+ chunk_size: int = 1,
  ):
  """
  Args:
@@ -111,6 +134,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  compute_nll_loss (bool): Whether to compute the NLL loss.
  compiled (bool): Whether to use the torch compiled kernel.
  use_ref_model (bool): Whether to use a reference model for the DPO loss.
+ chunk_size (int): Size of chunks for processing.
  """
  super().__init__()
  self.ignore_index = ignore_index
@@ -118,6 +142,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  self.compute_nll_loss = compute_nll_loss
  self.compiled = compiled
  self.use_ref_model = use_ref_model
+ self.chunk_size = chunk_size

  def forward(
  self,
@@ -142,4 +167,5 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  self.compute_nll_loss,
  self.compiled,
  self.use_ref_model,
+ self.chunk_size,
  )
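The DPO wrapper gains the same chunk_size knob and still routes reference-model activations through the forward call. A hedged sketch, assuming the module's forward accepts ref_input/ref_weight keywords as named in the Args list above (shapes as in the CPO example):

    import torch

    from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

    B, T, H, V = 2, 8, 16, 32
    lin_weight = torch.randn(V, H, requires_grad=True)   # policy LM-head weight
    ref_weight = torch.randn(V, H)                        # frozen reference LM-head weight
    _input = torch.randn(B, T, H, requires_grad=True)     # policy hidden states
    ref_input = torch.randn(B, T, H)                      # reference hidden states
    target = torch.randint(0, V, (B, T))

    dpo = LigerFusedLinearDPOLoss(beta=0.1, use_ref_model=True, chunk_size=1)  # chunk_size new in 0.5.5
    out = dpo(lin_weight, _input, target, ref_input=ref_input, ref_weight=ref_weight)
    loss = out[0] if isinstance(out, tuple) else out
    loss.backward()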
liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -125,6 +125,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):

  @staticmethod
  def forward(
+ cls,
  ctx,
  student_input,
  student_weight,
@@ -133,7 +134,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  target,
  student_bias=None,
  teacher_bias=None,
- loss_fn=None,
  chunk_size=1024,
  ignore_index=-100,
  weight_hard_loss=0.5,
@@ -175,7 +175,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):

  loss_func_to_call = partial(
  LigerFusedLinearDistillationBase._compute_loss,
- distillation_loss_fn=loss_fn,
+ distillation_loss_fn=cls.distillation_loss_fn,
  full_target=target,
  ignore_index=ignore_index,
  weight_hard_loss=weight_hard_loss,
@@ -263,4 +263,4 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  grad_weight = grad_weight * grad_output
  grad_bias = grad_bias * grad_output if grad_bias is not None else None

- return grad_input, grad_weight, None, grad_bias
+ return grad_input, grad_weight, None, None, None, grad_bias
liger_kernel/chunked_loss/fused_linear_preference.py
@@ -16,12 +16,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):

  @staticmethod
  def forward(
+ cls,
  ctx,
  _input,
  weight,
  target,
  bias=None,
- loss_fn=None,
  chunk_size=1,
  ignore_index=-100,
  alpha=1.0,
@@ -89,7 +89,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):

  compute_loss = partial(
  LigerFusedLinearPreferenceBase._compute_loss,
- preference_loss_fn=loss_fn,
+ preference_loss_fn=cls.preference_loss_fn,
  ignore_index=ignore_index,
  alpha=alpha,
  beta=beta,
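The net effect of these two base-class changes: subclasses no longer thread a loss_fn= argument through forward; the base resolves cls.preference_loss_fn itself, so a custom preference loss needs only a static loss function plus a thin classmethod forward, mirroring the CPO/DPO wrappers above. A minimal sketch; the hinge-style loss body and the exact keyword set forwarded to preference_loss_fn are illustrative assumptions:

    import torch
    import torch.nn.functional as F

    from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase

    class LigerFusedLinearHingeFunction(LigerFusedLinearPreferenceBase):
        @staticmethod
        def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, **kwargs):
            # Illustrative margin loss over the chosen/rejected log-prob gap.
            loss = F.relu(1.0 - beta * (chosen_logps - rejected_logps)).sum() / (full_target.shape[0] // 2)
            return loss, beta * chosen_logps, beta * rejected_logps

        @classmethod
        def forward(cls, ctx, _input, weight, target, bias=None, beta=0.1, compiled=True, chunk_size=1):
            # The base now looks up cls.preference_loss_fn on its own; no loss_fn= is passed.
            return super().forward(
                cls=cls, ctx=ctx, _input=_input, weight=weight, target=target, bias=bias,
                beta=beta, compute_nll_loss=False, compiled=compiled, chunk_size=chunk_size,
            )

        @staticmethod
        def backward(ctx, *grad_output):
            grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
            return *grads, None, None, None  # one None per non-tensor forward argument (beta, compiled, chunk_size)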
liger_kernel/chunked_loss/fused_linear_rlhf.py (new file)
@@ -0,0 +1,240 @@
+ from abc import abstractmethod
+ from functools import partial
+
+ import torch
+ import torch.nn.functional as F
+
+
+ class LigerFusedLinearRLHFBase(torch.autograd.Function):
+ @abstractmethod
+ def rlhf_loss_fn(*args, **kwargs):
+ """
+ To be extended by subclasses.
+ """
+ raise NotImplementedError("RLHF loss function must be implemented.")
+
+ @staticmethod
+ def forward(
+ cls,
+ ctx,
+ _input,
+ weight,
+ attention_mask,
+ rewards,
+ bias=None,
+ num_generations=4,
+ beta=0.1,
+ compiled=True,
+ use_ref_model=False,
+ ref_input=None,
+ ref_weight=None,
+ ref_bias=None,
+ chunk_size=1,
+ ):
+ """Chunked forward pass for RLHF loss computation.
+
+ Args:
+ cls: The class
+ ctx: Context for backward
+ _input: Input tensor
+ weight: Weight tensor
+ attention_mask: Attention mask tensor
+ rewards: Rewards tensor
+ bias: Bias tensor
+ num_generations: Number of generations per prompt
+ beta: Weight for the KL penalty
+ compiled: Whether to use torch compile
+ use_ref_model: Whether to use a reference model
+ ref_input: Reference model input tensor
+ ref_weight: Reference model weight tensor
+ ref_bias: Reference model bias tensor
+ chunk_size: Size of chunks for processing in other loss modules
+ """
+ # Save for backward
+ ctx.beta = beta
+ ctx.rewards = rewards
+
+ # Initialize accumulators
+ loss_acc = torch.zeros((), device=_input.device)
+ grad_weight = torch.zeros_like(weight) # [V, H]
+ grad_inputs = []
+ grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
+ aggregated_metrics = []
+
+ # Create a partial function with fixed arguments
+ compute_loss = partial(
+ LigerFusedLinearRLHFBase._compute_chunk_loss,
+ beta=beta,
+ use_ref_model=use_ref_model,
+ ref_weight=ref_weight,
+ ref_bias=ref_bias,
+ rlhf_loss_fn=cls.rlhf_loss_fn,
+ )
+
+ def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
+ """Fused forward and backward for a chunk."""
+ if bias is not None:
+ return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
+ input_chunk, # arg 0
+ weight, # arg 1
+ attention_mask_chunk, # arg 2
+ rewards_chunk, # arg 3
+ ref_input_chunk, # arg 4
+ bias, # arg 5
+ )
+ else:
+ return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
+ input_chunk, # arg 0
+ weight, # arg 1
+ attention_mask_chunk, # arg 2
+ rewards_chunk, # arg 3
+ ref_input_chunk, # arg 4
+ )
+
+ def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
+ if bias is not None:
+ (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
+ input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
+ )
+ grad_bias.add_(chunk_grad_bias)
+ else:
+ (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
+ input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
+ )
+
+ # Accumulate gradients and loss
+ grad_weight.add_(chunk_grad_weight)
+ grad_inputs.append(chunk_grad_input)
+ loss_acc.add_(chunk_loss)
+
+ # Initialize storage for metrics on first chunk
+ if len(aggregated_metrics) == 0:
+ for metric in chunk_metrics:
+ if metric.ndim == 0:
+ aggregated_metrics.append(torch.zeros((), device=metric.device))
+ else:
+ aggregated_metrics.append([])
+
+ # Accumulate metrics
+ for i, metric in enumerate(chunk_metrics):
+ if metric.ndim == 0:
+ aggregated_metrics[i].add_(metric)
+ else:
+ aggregated_metrics[i].append(metric)
+
+ if compiled:
+ accumulate_chunk = torch.compile(accumulate_chunk)
+
+ # Process input in chunks based on num_generations
+ chunks = max(1, _input.shape[0] // num_generations)
+ _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+ _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
+ _rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
+ _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks
+
+ for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
+ _input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
+ ):
+ # Mark dynamic dimensions
+ torch._dynamo.mark_dynamic(input_chunk, 1)
+ torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
+ if ref_input_chunk is not None:
+ torch._dynamo.mark_dynamic(ref_input_chunk, 1)
+
+ accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)
+
+ # Scale accumulated loss by number of chunks since we're averaging
+ loss_acc = loss_acc / chunks
+
+ # Combine gradients
+ grad_input = torch.cat(grad_inputs, dim=0)
+
+ # Save for backward
+ ctx.save_for_backward(grad_input, grad_weight, grad_bias)
+
+ # Finalize metrics
+ final_metrics = []
+ for metric in aggregated_metrics:
+ if isinstance(metric, list):
+ final_metrics.append(torch.cat(metric, dim=0))
+ else:
+ final_metrics.append(metric / chunks)
+
+ return loss_acc, tuple(final_metrics)
+
+ @staticmethod
+ def _compute_chunk_loss(
+ input_chunk,
+ weight,
+ attention_mask_chunk,
+ rewards_chunk,
+ ref_input_chunk=None,
+ bias=None,
+ beta=0.1,
+ use_ref_model=False,
+ ref_weight=None,
+ ref_bias=None,
+ rlhf_loss_fn=None,
+ ):
+ """Compute loss for a single chunk."""
+ # Get policy log probabilities using chunk_forward
+ log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)
+
+ # Get reference log probabilities if needed
+ ref_log_probs = None
+ if use_ref_model and ref_input_chunk is not None:
+ with torch.no_grad():
+ ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)
+
+ # Compute chunk loss and metrics using the provided loss function
+ chunk_loss, chunk_metrics = rlhf_loss_fn(
+ log_probs=log_probs,
+ attention_mask=attention_mask_chunk,
+ rewards=rewards_chunk,
+ ref_log_probs=ref_log_probs,
+ beta=beta,
+ )
+
+ return chunk_loss, (logits_mean, *chunk_metrics)
+
+ @staticmethod
+ def chunk_forward(input_chunk, weight, bias=None):
+ """Forward pass computation for a single chunk without explicit reshaping."""
+ # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
+ logits = torch.matmul(input_chunk, weight.t())
+ if bias is not None:
+ logits = logits + bias # Broadcasts bias to [B, T, V]
+
+ # Compute log probabilities using softmax over the last dimension
+ log_probs = F.log_softmax(logits.float(), dim=-1)
+
+ # Monitoring: compute mean of logits
+ batch_size, seq_len, _ = input_chunk.shape
+ logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
+ return log_probs, logits, logits_mean
+
+ @staticmethod
+ def backward(ctx, grad_output, *grad_metrics):
+ """Backward pass for RLHF loss."""
+ grad_input, grad_weight, grad_bias = ctx.saved_tensors
+ if grad_output != 1.0:
+ grad_input = grad_input * grad_output
+ grad_weight = grad_weight * grad_output
+ if grad_bias is not None:
+ grad_bias = grad_bias * grad_output
+
+ return (
+ grad_input,
+ grad_weight,
+ None, # grad_attention_mask
+ None, # grad_rewards
+ grad_bias,
+ None, # grad_num_generations
+ None, # grad_beta
+ None, # grad_compiled
+ None, # grad_use_ref_model
+ None, # grad_ref_input
+ None, # grad_ref_weight
+ None, # grad_ref_bias
+ None, # grad_chunk_size
+ )
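The new grpo_loss.py (also added in this release) builds on this base by supplying the rlhf_loss_fn that the abstract method above leaves unimplemented, presumably wired up through a classmethod forward that passes cls to the base, the same way the preference losses earlier in this diff do. The sketch below only illustrates that plug-in point; the loss body is a placeholder, not the actual GRPO objective, and the keyword names follow the rlhf_loss_fn call in _compute_chunk_loss:

    import torch

    from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase

    class MyRLHFFunction(LigerFusedLinearRLHFBase):
        @staticmethod
        def rlhf_loss_fn(log_probs, attention_mask, rewards, ref_log_probs=None, beta=0.1):
            # Placeholder objective: reward-weighted confidence of the policy's greedy tokens.
            token_scores = log_probs.max(dim=-1).values                       # [B, T]
            seq_scores = (token_scores * attention_mask).sum(-1) / attention_mask.sum(-1)
            loss = -(rewards * seq_scores).mean()
            if ref_log_probs is not None:
                # Exact per-token KL(policy || reference) over the vocabulary, masked and averaged.
                kl = (log_probs.exp() * (log_probs - ref_log_probs)).sum(-1)  # [B, T]
                loss = loss + beta * ((kl * attention_mask).sum(-1) / attention_mask.sum(-1)).mean()
            metrics = (seq_scores.mean(),)  # extra scalars are averaged across chunks by the base
            return loss, metrics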