liger-kernel 0.5.4__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -39,8 +39,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
 
         return loss, chosen_rewards, rejected_rewards
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -52,27 +53,48 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
         label_smoothing=0.0,
         compute_nll_loss=True,
         compiled=True,
+        average_log_prob=False,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
-            ctx,
-            _input,
-            weight,
-            target,
-            bias,
-            loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
+        """
+        Fused linear layer with CPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            alpha (float): Weight for the alpha parameter
+            label_smoothing (float): Label smoothing factor
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing.
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
             label_smoothing=label_smoothing,
             compute_nll_loss=compute_nll_loss,
-            average_log_prob=False,
+            average_log_prob=average_log_prob,
             compiled=compiled,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -88,11 +110,19 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
         label_smoothing: float = 0.0,
         compute_nll_loss: bool = True,
         compiled: bool = True,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            alpha (float): Weight for the alpha parameter.
+            label_smoothing (float): Label smoothing factor.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
@@ -101,8 +131,16 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
         self.label_smoothing = label_smoothing
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
+        self.average_log_prob = average_log_prob
+        self.chunk_size = chunk_size
 
-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+    ):
         return LigerFusedLinearCPOFunction.apply(
             _input,
             lin_weight,
@@ -114,4 +152,6 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
             self.label_smoothing,
             self.compute_nll_loss,
             self.compiled,
+            self.average_log_prob,
+            self.chunk_size,
         )
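
For orientation, a minimal usage sketch of the updated CPO module API follows. The constructor and forward signatures come from the hunks above; the import path and the structure of the returned value are assumptions, not taken from this diff. Note that the two new forward arguments (average_log_prob, chunk_size) are also why backward now returns two additional None gradients.

import torch

# Hedged sketch: exercising LigerFusedLinearCPOLoss with the new
# average_log_prob / chunk_size options. The import path below is an
# assumption; only the keyword names come from the diff above.
from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss

B, T, H, V = 2, 16, 64, 128  # arbitrary example sizes
cpo_loss = LigerFusedLinearCPOLoss(
    ignore_index=-100,
    beta=0.1,
    average_log_prob=False,  # new in 0.5.5: sum instead of mean of per-token log-probs
    chunk_size=1,            # new in 0.5.5: forwarded to the chunked base class
)

lin_weight = torch.randn(V, H, requires_grad=True)
_input = torch.randn(B * T, H, requires_grad=True)
target = torch.randint(0, V, (B * T,))

outputs = cpo_loss(lin_weight, _input, target)
# The first element is the scalar loss; any remaining entries are auxiliary
# metrics whose exact structure is not shown in this diff.
loss = outputs[0] if isinstance(outputs, tuple) else outputs
loss.backward()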
@@ -52,8 +52,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
         return loss, chosen_rewards, rejected_rewards
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -67,14 +68,34 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         use_ref_model=True,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
+        """
+        Fused linear layer with DPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            chunk_size (int): Size of chunks for processing.
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             bias=bias,
-            loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
@@ -83,12 +104,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -103,6 +125,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         compute_nll_loss: bool = False,
         compiled: bool = True,
         use_ref_model: bool = True,
+        chunk_size: int = 1,
     ):
         """
         Args:
@@ -111,6 +134,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             compute_nll_loss (bool): Whether to compute the NLL loss.
             compiled (bool): Whether to use the torch compiled kernel.
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
@@ -118,6 +142,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.chunk_size = chunk_size
 
     def forward(
         self,
@@ -142,4 +167,5 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.use_ref_model,
+            self.chunk_size,
         )
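
The DPO module gains the same chunk_size plumbing. A hedged construction sketch; the import path is an assumption, the keyword names come from the __init__ signature above:

# Hedged sketch: constructing the updated DPO loss module with the new
# chunk_size option. Other constructor options (e.g. ignore_index, beta)
# exist in the source but are not shown in this diff, so they are omitted.
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

dpo_loss = LigerFusedLinearDPOLoss(
    compute_nll_loss=False,
    compiled=True,
    use_ref_model=True,  # reference-model tensors are then expected at call time
    chunk_size=1,        # new in 0.5.5: forwarded to LigerFusedLinearPreferenceBase
)

Because chunk_size becomes an extra argument to the underlying Function.forward, backward returns one more None gradient, which is the change to the *grads, None, ... line above.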
@@ -125,6 +125,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 
     @staticmethod
     def forward(
+        cls,
         ctx,
         student_input,
         student_weight,
@@ -133,7 +134,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         target,
         student_bias=None,
         teacher_bias=None,
-        loss_fn=None,
         chunk_size=1024,
         ignore_index=-100,
         weight_hard_loss=0.5,
@@ -175,7 +175,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 
         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
-            distillation_loss_fn=loss_fn,
+            distillation_loss_fn=cls.distillation_loss_fn,
             full_target=target,
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
@@ -263,4 +263,4 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         grad_weight = grad_weight * grad_output
         grad_bias = grad_bias * grad_output if grad_bias is not None else None
 
-        return grad_input, grad_weight, None, grad_bias
+        return grad_input, grad_weight, None, None, None, grad_bias
@@ -16,12 +16,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
     @staticmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
         target,
         bias=None,
-        loss_fn=None,
         chunk_size=1,
         ignore_index=-100,
         alpha=1.0,
@@ -89,7 +89,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
         compute_loss = partial(
             LigerFusedLinearPreferenceBase._compute_loss,
-            preference_loss_fn=loss_fn,
+            preference_loss_fn=cls.preference_loss_fn,
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
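
The recurring refactor across these files is visible here: the loss_fn= parameter is dropped and the base class resolves the loss through the class object (cls.preference_loss_fn, and analogously cls.distillation_loss_fn and cls.rlhf_loss_fn), with the subclass forward methods becoming classmethods that pass cls down. A stripped-down sketch of that dispatch pattern, using made-up class names rather than the real liger-kernel ones:

import torch


class ChunkedPreferenceBase:
    # Minimal stand-in for the base class: forward receives the concrete
    # class and looks up cls.preference_loss_fn, replacing the removed
    # loss_fn=... argument.
    @staticmethod
    def forward(cls, scores):
        return cls.preference_loss_fn(scores)


class SigmoidPreference(ChunkedPreferenceBase):
    @staticmethod
    def preference_loss_fn(scores):
        return -torch.nn.functional.logsigmoid(scores).mean()

    @classmethod
    def forward(cls, scores):
        # mirrors `return super().forward(cls=cls, ...)` in the hunks above
        return super().forward(cls, scores)


loss = SigmoidPreference.forward(torch.randn(8))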
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 from functools import partial
 
 import torch
@@ -5,15 +6,22 @@ import torch.nn.functional as F
 
 
 class LigerFusedLinearRLHFBase(torch.autograd.Function):
+    @abstractmethod
+    def rlhf_loss_fn(*args, **kwargs):
+        """
+        To be extended by subclasses.
+        """
+        raise NotImplementedError("RLHF loss function must be implemented.")
+
     @staticmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
         attention_mask,
         rewards,
         bias=None,
-        loss_fn=None,
         num_generations=4,
         beta=0.1,
         compiled=True,
@@ -21,8 +29,27 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
+        chunk_size=1,
     ):
-        """Chunked forward pass for RLHF loss computation."""
+        """Chunked forward pass for RLHF loss computation.
+
+        Args:
+            cls: The class
+            ctx: Context for backward
+            _input: Input tensor
+            weight: Weight tensor
+            attention_mask: Attention mask tensor
+            rewards: Rewards tensor
+            bias: Bias tensor
+            num_generations: Number of generations per prompt
+            beta: Weight for the KL penalty
+            compiled: Whether to use torch compile
+            use_ref_model: Whether to use a reference model
+            ref_input: Reference model input tensor
+            ref_weight: Reference model weight tensor
+            ref_bias: Reference model bias tensor
+            chunk_size: Size of chunks for processing in other loss modules
+        """
         # Save for backward
         ctx.beta = beta
         ctx.rewards = rewards
@@ -41,7 +68,7 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
             use_ref_model=use_ref_model,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
-            rlhf_loss_fn=loss_fn,
+            rlhf_loss_fn=cls.rlhf_loss_fn,
         )
 
         def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
@@ -98,7 +125,7 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
         if compiled:
             accumulate_chunk = torch.compile(accumulate_chunk)
 
-        # Process input in chunks
+        # Process input in chunks based on num_generations
         chunks = max(1, _input.shape[0] // num_generations)
         _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
         _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
@@ -202,12 +229,12 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
             None,  # grad_attention_mask
             None,  # grad_rewards
             grad_bias,
-            None,  # grad_loss_fn
-            None,  # grad_chunk_size
+            None,  # grad_num_generations
             None,  # grad_beta
             None,  # grad_compiled
             None,  # grad_use_ref_model
             None,  # grad_ref_input
             None,  # grad_ref_weight
             None,  # grad_ref_bias
+            None,  # grad_chunk_size
         )
@@ -16,13 +16,13 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
 
     @staticmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
         target,
         preference_labels,
         bias=None,
-        loss_fn=None,
         chunk_size=1,
         ignore_index=-100,
         compiled=True,
@@ -30,6 +30,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
+        average_log_prob=False,
         **loss_kwargs,
     ):
         """
@@ -59,6 +60,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
                 Shape: (batch_size,).
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
            loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -72,14 +74,22 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
         # Loss to be accumulated
         loss_acc = torch.zeros((), device=_input.device)
 
+        # Metrics to be recorded
+        chosen_logps_sum = torch.zeros((), device=_input.device)
+        rejected_logps_sum = torch.zeros((), device=_input.device)
+        chosen_logits_sum = torch.zeros((), device=_input.device)
+        rejected_logits_sum = torch.zeros((), device=_input.device)
+        aggregated_aux_outputs = []
+
         compute_loss = partial(
             LigerFusedLinearUnpairedPreferenceBase._compute_loss,
-            preference_loss_fn=loss_fn,
+            preference_loss_fn=cls.preference_loss_fn,
             full_target=target,
             ignore_index=ignore_index,
             use_ref_model=use_ref_model,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             **loss_kwargs,
         )
 
@@ -88,7 +98,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             Fused forward and backward pass for a chunk of input and target.
             """
             argnums = (0, 1, 4) if bias is not None else (0, 1)
-            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=False)(
+            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
                 input_chunk,
                 weight,
                 target_chunk,
@@ -103,9 +113,19 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
            preference_labels_chunk=None,
            ref_input_chunk=None,
        ):
-            (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss) = fused_fwd_bwd(
-                input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk
-            )
+            (
+                (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias),
+                (
+                    chunk_loss,
+                    (
+                        chunk_chosen_logps_sum,
+                        chunk_rejected_logps_sum,
+                        chunk_chosen_logits_sum,
+                        chunk_rejected_logits_sum,
+                        *aux_outputs,
+                    ),
+                ),
+            ) = fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
             if bias is not None:
                 grad_bias.add_(chunk_grad_bias[0])  # accumulate bias gradient
 
@@ -116,6 +136,23 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             # Accumulate loss
             loss_acc.add_(chunk_loss)
 
+            # Accumulate metrics
+            chosen_logps_sum.add_(chunk_chosen_logps_sum)
+            rejected_logps_sum.add_(chunk_rejected_logps_sum)
+            chosen_logits_sum.add_(chunk_chosen_logits_sum)
+            rejected_logits_sum.add_(chunk_rejected_logits_sum)
+
+            # aux_outputs
+            # Initialize storage for aux_outputs
+            if len(aggregated_aux_outputs) == 0:
+                for aux in aux_outputs:
+                    aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
+
+            # Process each aux_output
+            for i, aux in enumerate(aux_outputs):
+                if aux.ndim == 0:
+                    aggregated_aux_outputs[i].add_(aux)
+
         if compiled:
             fused_fwd_bwd = torch.compile(fused_fwd_bwd)
 
@@ -151,12 +188,25 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             # accumulate loss, gradients, and metrics
             accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
 
+        # Aggregate aux outputs lists into tensors
+        for i, aux in enumerate(aggregated_aux_outputs):
+            if isinstance(aux, list):
+                aggregated_aux_outputs[i] = torch.cat(aux, dim=0)
+
         ctx.save_for_backward(
             torch.cat(grad_inputs, dim=0),
             grad_weight,
             grad_bias,
         )
-        return loss_acc
+
+        return_vars = (
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
+        )
+
+        return loss_acc, (*return_vars, *aggregated_aux_outputs)
 
     @staticmethod
     def backward(ctx, *grad_output):
@@ -173,21 +223,37 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
         input_chunk,
         weight,
         target_chunk,
+        preference_labels_chunk,
         bias=None,
         ignore_index=-100,
+        average_log_prob=False,
     ):
         logits_chunk = input_chunk @ weight.t()
         if bias is not None:
             logits_chunk = logits_chunk + bias
         log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
-
         loss_mask_chunk = target_chunk != ignore_index
         label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)
 
         per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
-        average_log_prob_chunk = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
-
-        return average_log_prob_chunk
+        if average_log_prob:
+            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
+        else:
+            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1)
+
+        chosen_logps_sum = (log_probs * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_logps_sum = (log_probs * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+        chosen_logits_sum = (logits_chunk * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_logits_sum = (logits_chunk * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+        return (
+            log_probs,
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
+        )
 
     @staticmethod
     def _compute_loss(
@@ -203,6 +269,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
         ref_input_chunk=None,
         ref_weight=None,
         ref_bias=None,
+        average_log_prob=False,
         **loss_kwargs,
     ):
         """
@@ -218,29 +285,57 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
            loss_kwargs (dict): Additional arguments for the loss function.
         """
-        average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+        (
+            log_prob_chunk,
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
+        ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
             input_chunk,
             weight,
             target_chunk,
+            preference_labels_chunk,
             bias=bias,
             ignore_index=ignore_index,
+            average_log_prob=average_log_prob,
         )
 
         if use_ref_model:
             with torch.no_grad():
-                ref_average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+                (
+                    ref_log_prob_chunk,
+                    _,
+                    _,
+                    _,
+                    _,
+                ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
                     ref_input_chunk,
                     ref_weight,
                     target_chunk,
+                    preference_labels_chunk,
                     ref_bias,
                     ignore_index=ignore_index,
+                    average_log_prob=average_log_prob,
                 )
-                loss_kwargs["ref_average_log_prob_chunk"] = ref_average_log_prob_chunk
+                loss_kwargs["ref_log_prob_chunk"] = ref_log_prob_chunk
 
-        preference_loss_chunk = preference_loss_fn(
-            average_log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
+        preference_loss_outputs = preference_loss_fn(
+            log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
+        )
+        if isinstance(preference_loss_outputs, tuple):
+            preference_loss_chunk, *aux_outputs = preference_loss_outputs
+        else:
+            preference_loss_chunk, aux_outputs = preference_loss_outputs, []
+
+        return_vars = (
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
         )
 
-        return preference_loss_chunk
+        return preference_loss_chunk, (*return_vars, *aux_outputs)
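
The unpaired preference base now returns summed metrics (plus any auxiliary outputs of the loss function) alongside the accumulated loss, and fused_fwd_bwd switches to has_aux=True accordingly. A small illustration of the new return structure, using dummy tensors rather than a real call:

import torch

# Dummy stand-in for what LigerFusedLinearUnpairedPreferenceBase.forward now
# returns: (loss_acc, (chosen_logps_sum, rejected_logps_sum,
# chosen_logits_sum, rejected_logits_sum, *aggregated_aux_outputs)).
outputs = (
    torch.tensor(0.69),
    (torch.tensor(-12.3), torch.tensor(-15.8), torch.tensor(4.2), torch.tensor(3.7)),
)

loss, (chosen_logps_sum, rejected_logps_sum, chosen_logits_sum, rejected_logits_sum, *aux) = outputs
# aux is empty here; it carries extra scalars when the subclass loss function
# returns a tuple (see the aux_outputs handling above).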
@@ -63,8 +63,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
 
         return loss, metrics
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -78,13 +79,33 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
         compiled=True,
         use_ref_model=True,
         num_generations=1,
+        chunk_size=1,
     ):
-        return LigerFusedLinearRLHFBase.forward(
+        """
+        Fused linear layer with GRPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
+            rewards (torch.Tensor): Rewards tensor. Shape: (batch_size,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            beta (float): Weight for the KL penalty
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            num_generations (int): Number of generations per prompt
+            chunk_size (int): Size of chunks for processing.
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             attention_mask=attention_mask,
-            loss_fn=LigerFusedLinearGRPOFunction.rlhf_loss_fn,
             rewards=rewards,
             bias=bias,
             ref_input=ref_input,
@@ -94,6 +115,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
             compiled=compiled,
             use_ref_model=use_ref_model,
             num_generations=num_generations,
+            chunk_size=chunk_size,
         )
 
     @staticmethod
@@ -114,6 +136,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
             None,  # grad_compiled
             None,  # grad_use_ref_model
             None,  # grad_num_generations
+            None,  # grad_chunk_size
         )
 
 
@@ -126,12 +149,22 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         compiled: bool = True,
         use_ref_model: bool = True,
         num_generations: int = 1,
+        chunk_size: int = 1,
    ):
+        """
+        Args:
+            beta (float): Weight for the KL penalty.
+            compiled (bool): Whether to use torch compile.
+            use_ref_model (bool): Whether to use a reference model.
+            num_generations (int): Number of generations per prompt.
+            chunk_size (int): Size of chunks for processing.
+        """
         super().__init__()
         self.beta = beta
         self.compiled = compiled
         self.use_ref_model = use_ref_model
         self.num_generations = num_generations
+        self.chunk_size = chunk_size
 
     def forward(
         self,
@@ -157,4 +190,5 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             self.compiled,
             self.use_ref_model,
             self.num_generations,
+            self.chunk_size,
         )
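
A hedged construction sketch for the updated GRPO module; the keyword names come from the __init__ signature above, while the import path is an assumption about the package layout:

# Hedged sketch: constructing the updated GRPO loss module with the new
# chunk_size knob.
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

grpo_loss = LigerFusedLinearGRPOLoss(
    beta=0.1,           # KL penalty weight
    compiled=True,
    use_ref_model=True,
    num_generations=4,  # also drives how the batch is chunked (see the RLHF base above)
    chunk_size=1,       # new in 0.5.5: forwarded to LigerFusedLinearRLHFBase
)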