liger-kernel 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff compares the contents of two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Files changed (64)
  1. liger_kernel/chunked_loss/README.md +25 -0
  2. liger_kernel/chunked_loss/__init__.py +2 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +18 -8
  4. liger_kernel/chunked_loss/dpo_loss.py +20 -10
  5. liger_kernel/chunked_loss/functional.py +4 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
  7. liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
  8. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  9. liger_kernel/chunked_loss/jsd_loss.py +154 -0
  10. liger_kernel/chunked_loss/kto_loss.py +172 -0
  11. liger_kernel/chunked_loss/orpo_loss.py +8 -9
  12. liger_kernel/chunked_loss/simpo_loss.py +22 -8
  13. liger_kernel/env_report.py +5 -12
  14. liger_kernel/ops/cross_entropy.py +102 -51
  15. liger_kernel/ops/experimental/embedding.py +1 -3
  16. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  17. liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
  18. liger_kernel/ops/fused_linear_jsd.py +11 -29
  19. liger_kernel/ops/geglu.py +6 -17
  20. liger_kernel/ops/group_norm.py +11 -28
  21. liger_kernel/ops/jsd.py +2 -6
  22. liger_kernel/ops/kl_div.py +8 -11
  23. liger_kernel/ops/layer_norm.py +3 -5
  24. liger_kernel/ops/qwen2vl_mrope.py +21 -37
  25. liger_kernel/ops/rms_norm.py +14 -32
  26. liger_kernel/ops/rope.py +31 -33
  27. liger_kernel/ops/swiglu.py +4 -8
  28. liger_kernel/ops/utils.py +2 -0
  29. liger_kernel/transformers/__init__.py +16 -24
  30. liger_kernel/transformers/auto_model.py +6 -13
  31. liger_kernel/transformers/cross_entropy.py +4 -6
  32. liger_kernel/transformers/experimental/embedding.py +1 -3
  33. liger_kernel/transformers/functional.py +11 -7
  34. liger_kernel/transformers/fused_linear_cross_entropy.py +12 -7
  35. liger_kernel/transformers/geglu.py +1 -4
  36. liger_kernel/transformers/group_norm.py +3 -9
  37. liger_kernel/transformers/jsd.py +1 -3
  38. liger_kernel/transformers/kl_div.py +1 -3
  39. liger_kernel/transformers/layer_norm.py +3 -9
  40. liger_kernel/transformers/model/gemma.py +18 -40
  41. liger_kernel/transformers/model/gemma2.py +19 -41
  42. liger_kernel/transformers/model/llama.py +22 -48
  43. liger_kernel/transformers/model/mistral.py +14 -26
  44. liger_kernel/transformers/model/mixtral.py +24 -54
  45. liger_kernel/transformers/model/mllama.py +16 -36
  46. liger_kernel/transformers/model/phi3.py +18 -40
  47. liger_kernel/transformers/model/qwen2.py +18 -40
  48. liger_kernel/transformers/model/qwen2_vl.py +36 -32
  49. liger_kernel/transformers/monkey_patch.py +43 -117
  50. liger_kernel/transformers/qwen2vl_mrope.py +2 -2
  51. liger_kernel/transformers/rms_norm.py +4 -4
  52. liger_kernel/transformers/rope.py +2 -2
  53. liger_kernel/transformers/swiglu.py +2 -8
  54. liger_kernel/transformers/trainer/__init__.py +1 -3
  55. liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
  56. liger_kernel/triton/__init__.py +1 -3
  57. liger_kernel/triton/monkey_patch.py +1 -3
  58. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/METADATA +38 -25
  59. liger_kernel-0.5.3.dist-info/RECORD +69 -0
  60. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/WHEEL +1 -1
  61. liger_kernel-0.5.1.dist-info/RECORD +0 -65
  62. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/LICENSE +0 -0
  63. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/NOTICE +0 -0
  64. {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_preference.py

@@ -2,11 +2,11 @@ from abc import abstractmethod
 from functools import partial
 
 import torch
+
 from torch.nn import functional as F
 
 
 class LigerFusedLinearPreferenceBase(torch.autograd.Function):
-
     @abstractmethod
     def preference_loss_fn(*args, **kwargs):
         """
@@ -27,11 +27,13 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         alpha=1.0,
         beta=0.1,
         compute_nll_loss=True,
+        nll_target=None,
         compiled=True,
         use_ref_model=False,
-        # TODO: ref input
+        ref_input=None,
         ref_weight=None,
         ref_bias=None,
+        average_log_prob=True,
         **loss_kwargs,
     ):
         """
@@ -57,10 +59,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             alpha (float): Weight for the NLL loss.
             beta (float): Weight for the preference loss.
             compute_nll_loss (bool): Whether to compute NLL loss.
+            nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided the target is used.
             compiled (bool): Whether to use torch compile for chunk accumulation.
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
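
As a rough, hypothetical illustration of the two new knobs (not taken from the diff): `target` drives the preference term with prompt tokens masked via `ignore_index`, while the optional `nll_target` lets the auxiliary NLL term use a different labelling of the same sequences. Shapes and values below are made up.

    import torch

    batch, seq_len, ignore_index = 4, 6, -100   # assumed sizes

    # Preference labels: only completion tokens count; prompt positions are masked.
    target = torch.randint(0, 32000, (batch, seq_len))
    target[:, :3] = ignore_index

    # Optional separate labels for the NLL term (e.g. the unmasked full sequence).
    nll_target = torch.randint(0, 32000, (batch, seq_len))

    # average_log_prob=True  -> per-sequence mean of token log-probs (length-normalized)
    # average_log_prob=False -> per-sequence sum, as DPO-style losses expect
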
@@ -94,55 +98,70 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             use_ref_model=use_ref_model,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            full_nll_target=nll_target,
+            average_log_prob=average_log_prob,
             **loss_kwargs,
         )
 
-        def fused_fwd_bwd(input_chunk, target_chunk):
+        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk):
             """
             Fused forward and backward pass for a chunk of input and target.
             """
             if bias is not None:
-                return torch.func.grad_and_value(
-                    compute_loss, argnums=(0, 1, 3), has_aux=True
-                )(input_chunk, weight, target_chunk, bias)
+                return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 3), has_aux=True)(
+                    input_chunk,
+                    weight,
+                    target_chunk,
+                    bias,
+                    ref_input_chunk=ref_input_chunk,
+                    chosen_nll_target_chunk=chosen_nll_target_chunk,
+                )
             else:
-                return torch.func.grad_and_value(
-                    compute_loss, argnums=(0, 1), has_aux=True
-                )(input_chunk, weight, target_chunk)
+                return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
+                    input_chunk,
+                    weight,
+                    target_chunk,
+                    ref_input_chunk=ref_input_chunk,
+                    chosen_nll_target_chunk=chosen_nll_target_chunk,
+                )
 
-        def accumulate_chunk(input_chunk, target_chunk):
+        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None):
             if bias is not None:
-                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
-                    chunk_loss,
+                (
+                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                     (
-                        chunk_chosen_logps,
-                        chunk_rejected_logps,
-                        chunk_chosen_logits_mean,
-                        chunk_rejected_logits_mean,
-                        chunk_nll_loss,
-                        *aux_outputs,
+                        chunk_loss,
+                        (
+                            chunk_chosen_logps,
+                            chunk_rejected_logps,
+                            chunk_chosen_logits_mean,
+                            chunk_rejected_logits_mean,
+                            chunk_nll_loss,
+                            *aux_outputs,
+                        ),
                     ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk)
+                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
                 grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
-                (chunk_grad_input, chunk_grad_weight), (
-                    chunk_loss,
+                (
+                    (chunk_grad_input, chunk_grad_weight),
                     (
-                        chunk_chosen_logps,
-                        chunk_rejected_logps,
-                        chunk_chosen_logits_mean,
-                        chunk_rejected_logits_mean,
-                        chunk_nll_loss,
-                        *aux_outputs,
+                        chunk_loss,
+                        (
+                            chunk_chosen_logps,
+                            chunk_rejected_logps,
+                            chunk_chosen_logits_mean,
+                            chunk_rejected_logits_mean,
+                            chunk_nll_loss,
+                            *aux_outputs,
+                        ),
                     ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk)
+                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
 
             # Accumulate gradients
             grad_weight.add_(chunk_grad_weight)
             grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
-            grad_rejected_inputs.append(
-                chunk_grad_input[chosen_target_chunk.shape[0] :]
-            )
+            grad_rejected_inputs.append(chunk_grad_input[chosen_target_chunk.shape[0] :])
 
             # Accumulate loss
             loss_acc.add_(chunk_loss)
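
The refactor keeps the `torch.func.grad_and_value` pattern and simply threads the new chunk arguments through as keywords. A standalone sketch of that API on a toy function (not the kernel's actual loss), to show how `argnums` and `has_aux` behave:

    import torch

    def toy_loss(x, w, target):
        # differentiable scalar loss plus a non-differentiated auxiliary output
        pred = x @ w
        return ((pred - target) ** 2).mean(), pred.detach()

    x, w, target = torch.randn(4, 3), torch.randn(3, 1), torch.randn(4, 1)

    # argnums selects which positional args receive gradients; with has_aux=True the
    # callable returns (loss, aux) and only the loss is differentiated.
    (grad_x, grad_w), (loss, pred) = torch.func.grad_and_value(
        toy_loss, argnums=(0, 1), has_aux=True
    )(x, w, target)
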
@@ -159,9 +178,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             if len(aggregated_aux_outputs) == 0:
                 for aux in aux_outputs:
                     if aux.ndim == 0:
-                        aggregated_aux_outputs.append(
-                            torch.zeros((), device=aux.device)
-                        )
+                        aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
                     else:
                         aggregated_aux_outputs.append([])
 
@@ -182,29 +199,46 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
         _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
 
+        if nll_target is not None:
+            _chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)
+
+        if use_ref_model:
+            _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
+            _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)
+
         for (
             chosen_input_chunk,
             rejected_input_chunk,
             chosen_target_chunk,
             rejected_target_chunk,
+            ref_chosen_input_chunk,
+            ref_rejected_input_chunk,
+            chosen_nll_target_chunk,
         ) in zip(
             _chosen_input_chunks,
             _rejected_input_chunks,
             _chosen_target_chunks,
             _rejected_target_chunks,
+            (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
+            (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
+            (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
+            strict=False,
         ):
             input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
-            target_chunk = torch.cat(
-                [chosen_target_chunk, rejected_target_chunk], dim=0
+            ref_input_chunk = (
+                torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) if use_ref_model else None
             )
+            target_chunk = torch.cat([chosen_target_chunk, rejected_target_chunk], dim=0)
 
             # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
             torch._dynamo.mark_dynamic(input_chunk, 1)
             torch._dynamo.mark_dynamic(target_chunk, 1)
             torch._dynamo.mark_dynamic(target, 1)
+            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
+            torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None
 
             # accumulate loss, gradients, and metrics
-            accumulate_chunk(input_chunk, target_chunk)
+            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
 
         # combine grad_chosen_inputs and grad_rejected_inputs
         grad_inputs = grad_chosen_inputs + grad_rejected_inputs
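
The widened zip feeds the optional streams (reference inputs, NLL targets) through the same chunk loop by substituting `[None] * len(...)` placeholder lists when a feature is off, so the loop body keeps a single code path. A minimal sketch of the pattern with made-up tensors:

    import torch

    use_ref_model = False
    _input = torch.randn(4, 6, 8)
    ref_input = torch.randn(4, 6, 8) if use_ref_model else None

    chunks = 2
    _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
    _ref_chunks = (
        torch.chunk(ref_input, chunks=chunks, dim=0)
        if use_ref_model
        else [None] * len(_input_chunks)
    )

    for input_chunk, ref_chunk in zip(_input_chunks, _ref_chunks):
        # ref_chunk is simply None when no reference model is used
        pass
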
@@ -233,14 +267,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grad_output):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
-        if torch.ne(
-            grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)
-        ):
+        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
             grad_input = grad_input * grad_output[0][0]
             grad_weight = grad_weight * grad_output[0][0]
             grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
 
-        return grad_input, grad_weight, None, grad_bias, None, None, None
+        return grad_input, grad_weight, None, grad_bias, None, None, None, None
 
     @staticmethod
     def chunk_forward(
@@ -250,6 +282,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         bias=None,
         ignore_index=-100,
         compute_nll_loss=True,
+        chosen_nll_target_chunk=None,
+        average_log_prob=True,
     ):
         len_chosen_chunk = target_chunk.shape[0] // 2
         logits_chunk = input_chunk @ weight.t()
@@ -259,9 +293,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
         chosen_nll_loss = 0.0
         if compute_nll_loss:
+            nll_labels = (
+                chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
+            )
             chosen_nll_loss = F.nll_loss(
                 log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-                target_chunk[:len_chosen_chunk].view(-1),
+                nll_labels.view(-1),
                 reduction="sum",
                 ignore_index=ignore_index,
             )
@@ -269,13 +306,14 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         loss_mask = target_chunk != ignore_index
         label_chunk = torch.where(loss_mask, target_chunk, 0)
 
-        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-            -1
-        )
-        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
+        if average_log_prob:
+            log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        else:
+            log_prob = (per_token_logps * loss_mask).sum(-1)
 
-        chosen_logps = average_log_prob[:len_chosen_chunk]
-        rejected_logps = average_log_prob[len_chosen_chunk:]
+        chosen_logps = log_prob[:len_chosen_chunk]
+        rejected_logps = log_prob[len_chosen_chunk:]
 
         chosen_logits = logits_chunk[:len_chosen_chunk]
         rejected_logits = logits_chunk[len_chosen_chunk:]
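
The `average_log_prob` switch is the substantive change in this hunk: a length-normalized per-sequence log-probability versus a plain sum of token log-probabilities over the non-ignored positions. A self-contained sketch of the two reductions:

    import torch
    import torch.nn.functional as F

    ignore_index = -100
    logits = torch.randn(2, 5, 11)            # (batch, seq_len, vocab)
    target = torch.randint(0, 11, (2, 5))
    target[:, :2] = ignore_index              # masked prompt tokens

    log_probs = F.log_softmax(logits.float(), dim=-1)
    loss_mask = target != ignore_index
    labels = torch.where(loss_mask, target, 0)
    per_token_logps = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

    summed = (per_token_logps * loss_mask).sum(-1)     # average_log_prob=False
    averaged = summed / loss_mask.sum(-1)              # average_log_prob=True
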
@@ -301,8 +339,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         beta=0.1,
         compute_nll_loss=True,
         use_ref_model=False,
+        ref_input_chunk=None,
         ref_weight=None,
         ref_bias=None,
+        full_nll_target=None,
+        chosen_nll_target_chunk=None,
+        average_log_prob=True,
         **loss_kwargs,
     ):
         """
@@ -321,6 +363,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length).
+            chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length) If not provided the target_chunk is used.
+            average_log_prob (bool): Whether to average log probabilities or the sum.
             loss_kwargs (dict): Additional arguments for the loss function.
         """
         (
@@ -336,14 +381,15 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             bias=bias,
             ignore_index=ignore_index,
             compute_nll_loss=compute_nll_loss,
+            chosen_nll_target_chunk=chosen_nll_target_chunk,
+            average_log_prob=average_log_prob,
         )
-        chosen_nll_loss = (
-            chosen_nll_loss
-            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-        )
-        chosen_logits_mean = chosen_logits.sum() / (
-            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
-        )
+        if full_nll_target is not None:
+            chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum()
+        else:
+            chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+
+        chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
         rejected_logits_mean = rejected_logits.sum() / (
             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
         )
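
The chosen NLL term remains a token-level sum normalized by the count of non-ignored chosen tokens; when `full_nll_target` is given, that count simply comes from it instead of `full_target`. A toy version of the normalization (made-up numbers):

    import torch

    ignore_index = -100
    full_target = torch.randint(0, 100, (4, 6))   # first half chosen, second half rejected
    full_target[:, :2] = ignore_index

    nll_sum = torch.tensor(37.5)                  # stands in for the summed chosen NLL
    num_chosen_tokens = (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
    nll_loss = nll_sum / num_chosen_tokens
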
@@ -353,16 +399,18 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                 (
                     ref_chosen_logps,
                     ref_rejected_logps,
-                    ref_chosen_logits,
-                    ref_rejected_logits,
-                    ref_chosen_nll_loss,
+                    _,
+                    _,
+                    _,
                 ) = LigerFusedLinearPreferenceBase.chunk_forward(
-                    input_chunk,
+                    ref_input_chunk,
                     ref_weight,
                     target_chunk,
                     ref_bias,
                     ignore_index=ignore_index,
                     compute_nll_loss=False,  # We don't need NLL loss for the reference model
+                    chosen_nll_target_chunk=None,
+                    average_log_prob=average_log_prob,
                 )
             loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
             loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
@@ -375,7 +423,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         else:
             preference_loss, aux_outputs = preference_loss_outputs, []
 
-        loss = alpha * chosen_nll_loss - preference_loss
+        loss = alpha * chosen_nll_loss + preference_loss
         return_vars = (
             chosen_logps,
             rejected_logps,
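
The sign flip from `- preference_loss` to `+ preference_loss` matches a convention in which the preference loss functions return a quantity that is already a loss to minimize, so it can be added directly to the weighted NLL term. A simplified, DPO-flavoured illustration of that convention (not liger-kernel's exact loss):

    import torch
    import torch.nn.functional as F

    def toy_preference_loss(chosen_logps, rejected_logps, beta=0.1):
        # returns a positive loss: smaller is better, so callers simply add it
        return -F.logsigmoid(beta * (chosen_logps - rejected_logps)).mean()

    chosen_logps = torch.tensor([-12.0, -9.5])
    rejected_logps = torch.tensor([-14.0, -10.0])
    alpha, nll_loss = 1.0, torch.tensor(2.3)

    total = alpha * nll_loss + toy_preference_loss(chosen_logps, rejected_logps)
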
liger_kernel/chunked_loss/fused_linear_unpaired_preference.py (new file)

@@ -0,0 +1,246 @@
+from abc import abstractmethod
+from functools import partial
+
+import torch
+
+from torch.nn import functional as F
+
+
+class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
+    @abstractmethod
+    def preference_loss_fn(*args, **kwargs):
+        """
+        To be extended by subclasses.
+        """
+        raise NotImplementedError("Preference loss function must be implemented.")
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        preference_labels,
+        bias=None,
+        loss_fn=None,
+        chunk_size=1,
+        ignore_index=-100,
+        compiled=True,
+        use_ref_model=False,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Base class for fused linear layer with unpaired preference loss like KTO
+        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
+
+        The mental model is:
+
+        forward()
+        ├── Loop over chunks
+        └── compute_loss()
+            ├── chunk_forward()  # Compute logits and log probs
+            └── prefer_loss()  # Calculate preference loss
+
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the preference loss.
+            compiled (bool): Whether to use torch compile for chunk accumulation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            preference_labels (torch.Tensor): Boolean tensor indicating chosen (True) vs rejected (False) examples.
+                Shape: (batch_size,).
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Other possible arguments that a loss function might need
+        """
+        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
+        CHUNK_SIZE = chunk_size
+
+        # Gradients to be accumulated
+        grad_inputs = []
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias) if bias is not None else None
+
+        # Loss to be accumulated
+        loss_acc = torch.zeros((), device=_input.device)
+
+        compute_loss = partial(
+            LigerFusedLinearUnpairedPreferenceBase._compute_loss,
+            preference_loss_fn=loss_fn,
+            full_target=target,
+            ignore_index=ignore_index,
+            use_ref_model=use_ref_model,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            **loss_kwargs,
+        )
+
+        def fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk):
+            """
+            Fused forward and backward pass for a chunk of input and target.
+            """
+            argnums = (0, 1, 4) if bias is not None else (0, 1)
+            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=False)(
+                input_chunk,
+                weight,
+                target_chunk,
+                preference_labels_chunk,
+                bias,
+                ref_input_chunk=ref_input_chunk,
+            )
+
+        def accumulate_chunk(
+            input_chunk,
+            target_chunk,
+            preference_labels_chunk=None,
+            ref_input_chunk=None,
+        ):
+            (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss) = fused_fwd_bwd(
+                input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk
+            )
+            if bias is not None:
+                grad_bias.add_(chunk_grad_bias[0])  # accumulate bias gradient
+
+            # Accumulate gradients
+            grad_weight.add_(chunk_grad_weight)
+            grad_inputs.append(chunk_grad_input)
+
+            # Accumulate loss
+            loss_acc.add_(chunk_loss)
+
+        if compiled:
+            fused_fwd_bwd = torch.compile(fused_fwd_bwd)
+
+        # When not paired, use labels to separate chosen and rejected
+        assert preference_labels is not None, "preference_labels must be provided for unpaired preference loss"
+
+        chunks = max(1, _input.shape[0] // CHUNK_SIZE)
+        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+        _target_chunks = torch.chunk(target, chunks=chunks, dim=0)
+        _preference_labels_chunks = torch.chunk(preference_labels, chunks=chunks, dim=0)
+
+        if use_ref_model:
+            _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0)
+
+        for (
+            input_chunk,
+            target_chunk,
+            ref_input_chunk,
+            preference_labels_chunk,
+        ) in zip(
+            _input_chunks,
+            _target_chunks,
+            (_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)),
+            _preference_labels_chunks,
+        ):
+            # mark input_chunk, target_chunk, and target dimension 1 (sequence length) as dynamic to prevent torch.compile recompilation
+            torch._dynamo.mark_dynamic(input_chunk, 1)
+            torch._dynamo.mark_dynamic(target_chunk, 1)
+            torch._dynamo.mark_dynamic(target, 1)
+            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
+            torch._dynamo.mark_dynamic(preference_labels_chunk, 1)
+
+            # accumulate loss, gradients, and metrics
+            accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
+
+        ctx.save_for_backward(
+            torch.cat(grad_inputs, dim=0),
+            grad_weight,
+            grad_bias,
+        )
+        return loss_acc
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
+            grad_input = grad_input * grad_output[0][0]
+            grad_weight = grad_weight * grad_output[0][0]
+            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
+
+        return grad_input, grad_weight, None, None, grad_bias
+
+    @staticmethod
+    def chunk_forward(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        ignore_index=-100,
+    ):
+        logits_chunk = input_chunk @ weight.t()
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        loss_mask_chunk = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)
+
+        per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
+        average_log_prob_chunk = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
+
+        return average_log_prob_chunk
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        preference_labels_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        use_ref_model=False,
+        ref_input_chunk=None,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+            input_chunk,
+            weight,
+            target_chunk,
+            bias=bias,
+            ignore_index=ignore_index,
+        )
+
+        if use_ref_model:
+            with torch.no_grad():
+                ref_average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+                    ref_input_chunk,
+                    ref_weight,
+                    target_chunk,
+                    ref_bias,
+                    ignore_index=ignore_index,
+                )
+            loss_kwargs["ref_average_log_prob_chunk"] = ref_average_log_prob_chunk
+
+        preference_loss_chunk = preference_loss_fn(
+            average_log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
+        )
+
+        return preference_loss_chunk
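
To make the `loss_fn` contract of the new unpaired base concrete: it receives the per-sequence average log-probabilities, the boolean preference labels, the full target (useful for normalization), plus any extra kwargs such as `ref_average_log_prob_chunk`. A minimal, hypothetical loss in that shape — a simplified stand-in, not the packaged KTO implementation:

    import torch
    import torch.nn.functional as F

    def toy_unpaired_loss(
        average_log_prob_chunk,       # (chunk_size,) policy log-probs, length-normalized
        preference_labels_chunk,      # (chunk_size,) bool: True = desirable, False = undesirable
        full_target,                  # (batch_size, seq_len); used here only for normalization
        ref_average_log_prob_chunk=None,
        beta=0.1,
    ):
        ref = ref_average_log_prob_chunk if ref_average_log_prob_chunk is not None else 0.0
        log_ratio = average_log_prob_chunk - ref
        # push desirable examples toward higher log-ratios and undesirable ones lower
        sign = preference_labels_chunk.float() * 2 - 1
        return (-F.logsigmoid(beta * sign * log_ratio)).sum() / full_target.shape[0]

A subclass would pass a function of this shape as `loss_fn` to the base class's `forward`.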