liger-kernel 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/__init__.py +4 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +107 -0
  4. liger_kernel/chunked_loss/dpo_loss.py +95 -17
  5. liger_kernel/chunked_loss/functional.py +9 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +252 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +245 -65
  8. liger_kernel/chunked_loss/orpo_loss.py +63 -13
  9. liger_kernel/chunked_loss/simpo_loss.py +115 -0
  10. liger_kernel/env_report.py +22 -0
  11. liger_kernel/ops/cross_entropy.py +17 -10
  12. liger_kernel/ops/fused_linear_cross_entropy.py +0 -11
  13. liger_kernel/ops/fused_linear_jsd.py +1 -1
  14. liger_kernel/ops/jsd.py +19 -10
  15. liger_kernel/ops/layer_norm.py +6 -1
  16. liger_kernel/ops/qwen2vl_mrope.py +238 -0
  17. liger_kernel/ops/rms_norm.py +6 -1
  18. liger_kernel/ops/utils.py +5 -2
  19. liger_kernel/transformers/functional.py +128 -11
  20. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  21. liger_kernel/transformers/jsd.py +1 -4
  22. liger_kernel/transformers/monkey_patch.py +6 -4
  23. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  24. liger_kernel/transformers/trainer/__init__.py +6 -0
  25. liger_kernel/transformers/trainer/orpo_trainer.py +169 -0
  26. liger_kernel/utils.py +13 -0
  27. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/METADATA +71 -47
  28. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/RECORD +32 -22
  29. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/WHEEL +1 -1
  30. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/LICENSE +0 -0
  31. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/NOTICE +0 -0
  32. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_preference.py

@@ -8,13 +8,9 @@ from torch.nn import functional as F
 class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
     @abstractmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+    def preference_loss_fn(*args, **kwargs):
         """
-        Compute preference loss.
-        Args:
-            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
-            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-            beta (float): Weight for the odds ratio loss.
+        To be extended by subclasses.
         """
         raise NotImplementedError("Preference loss function must be implemented.")
 
@@ -27,15 +23,29 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         bias=None,
         loss_fn=None,
         chunk_size=1,
-        compute_nll_loss=True,
         ignore_index=-100,
+        alpha=1.0,
         beta=0.1,
+        compute_nll_loss=True,
         compiled=True,
+        use_ref_model=False,
+        # TODO: ref input
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
     ):
         """
         Base class for fused linear layer with preference loss.
         Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
 
+        The mental model is:
+
+        forward()
+        ├── Loop over chunks
+            └── compute_loss()
+                ├── chunk_forward()  # Compute logits and log probs
+                └── prefer_loss()    # Calculate preference loss
+
         Args:
             _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
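As the docstring notes, `_input` and `target` stack the chosen and rejected sequences along the batch dimension, chosen rows first. A minimal sketch of preparing such a stacked batch (tensor names and sizes here are illustrative, not taken from the package):

```python
import torch

B, T, H, V = 4, 128, 256, 32000  # illustrative: pairs, seq len, hidden size, vocab size
chosen_hidden = torch.randn(B, T, H)     # hidden states for the chosen responses
rejected_hidden = torch.randn(B, T, H)   # hidden states for the rejected responses
chosen_labels = torch.randint(0, V, (B, T))
rejected_labels = torch.randint(0, V, (B, T))

# Chosen rows first, rejected rows second, as forward() expects
_input = torch.cat([chosen_hidden, rejected_hidden], dim=0)  # (2 * B, T, H)
target = torch.cat([chosen_labels, rejected_labels], dim=0)  # (2 * B, T)
```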
@@ -43,55 +53,130 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
-            compute_nll_loss (bool): Whether to compute NLL loss.
             ignore_index (int): Index to ignore for loss computation.
-            beta (float): Weight for the odds ratio loss.
+            alpha (float): Weight for the NLL loss.
+            beta (float): Weight for the preference loss.
+            compute_nll_loss (bool): Whether to compute NLL loss.
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
         CHUNK_SIZE = chunk_size
 
+        # Gradients to be accumulated
         grad_weight = torch.zeros_like(weight)
         grad_chosen_inputs = []
         grad_rejected_inputs = []
         grad_bias = torch.zeros_like(bias) if bias is not None else None
+
+        # Loss to be accumulated
         loss_acc = torch.zeros((), device=_input.device)
 
-        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
-        loss_func_to_call = partial(
+        # Metrics to be recorded
+        policy_chosen_logps = []
+        policy_rejected_logps = []
+        policy_chosen_logits_mean = torch.zeros((), device=_input.device)
+        policy_rejected_logits_mean = torch.zeros((), device=_input.device)
+        policy_nll_loss = torch.zeros((), device=_input.device)
+        aggregated_aux_outputs = []  # aggregated aux outputs from all chunks
+
+        compute_loss = partial(
             LigerFusedLinearPreferenceBase._compute_loss,
             preference_loss_fn=loss_fn,
             ignore_index=ignore_index,
+            alpha=alpha,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             full_target=target,
+            use_ref_model=use_ref_model,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            **loss_kwargs,
         )
 
+        def fused_fwd_bwd(input_chunk, target_chunk):
+            """
+            Fused forward and backward pass for a chunk of input and target.
+            """
+            if bias is not None:
+                return torch.func.grad_and_value(
+                    compute_loss, argnums=(0, 1, 3), has_aux=True
+                )(input_chunk, weight, target_chunk, bias)
+            else:
+                return torch.func.grad_and_value(
+                    compute_loss, argnums=(0, 1), has_aux=True
+                )(input_chunk, weight, target_chunk)
+
         def accumulate_chunk(input_chunk, target_chunk):
             if bias is not None:
                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
                     chunk_loss,
-                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
-                ) = torch.func.grad_and_value(
-                    loss_func_to_call, argnums=(0, 1, 3), has_aux=True
-                )(
-                    input_chunk, weight, target_chunk, bias
-                )
-                grad_bias.add_(chunk_grad_bias)
+                    (
+                        chunk_chosen_logps,
+                        chunk_rejected_logps,
+                        chunk_chosen_logits_mean,
+                        chunk_rejected_logits_mean,
+                        chunk_nll_loss,
+                        *aux_outputs,
+                    ),
+                ) = fused_fwd_bwd(input_chunk, target_chunk)
+                grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
                 (chunk_grad_input, chunk_grad_weight), (
                     chunk_loss,
-                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
-                ) = torch.func.grad_and_value(
-                    loss_func_to_call, argnums=(0, 1), has_aux=True
-                )(
-                    input_chunk, weight, target_chunk
-                )
+                    (
+                        chunk_chosen_logps,
+                        chunk_rejected_logps,
+                        chunk_chosen_logits_mean,
+                        chunk_rejected_logits_mean,
+                        chunk_nll_loss,
+                        *aux_outputs,
+                    ),
+                ) = fused_fwd_bwd(input_chunk, target_chunk)
+
+            # Accumulate gradients
             grad_weight.add_(chunk_grad_weight)
+            grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
+            grad_rejected_inputs.append(
+                chunk_grad_input[chosen_target_chunk.shape[0] :]
+            )
+
+            # Accumulate loss
             loss_acc.add_(chunk_loss)
-            return chunk_grad_input
+
+            # Accumulate metrics
+            policy_chosen_logps.append(chunk_chosen_logps)
+            policy_rejected_logps.append(chunk_rejected_logps)
+            policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
+            policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
+            policy_nll_loss.add_(chunk_nll_loss)
+
+            # aux_outputs
+            # Initialize storage for aux_outputs
+            if len(aggregated_aux_outputs) == 0:
+                for aux in aux_outputs:
+                    if aux.ndim == 0:
+                        aggregated_aux_outputs.append(
+                            torch.zeros((), device=aux.device)
+                        )
+                    else:
+                        aggregated_aux_outputs.append([])
+
+            # Process each aux_output
+            for i, aux in enumerate(aux_outputs):
+                if aux.ndim == 0:
+                    aggregated_aux_outputs[i].add_(aux)
+                else:
+                    aggregated_aux_outputs[i].append(aux)
+
+        if compiled:
+            fused_fwd_bwd = torch.compile(fused_fwd_bwd)
 
         len_chosen = target.shape[0] // 2
+        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
         _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
         _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
         _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
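The new `fused_fwd_bwd` helper relies on `torch.func.grad_and_value` with `has_aux=True`: the wrapped loss returns `(loss, aux)`, and the transformed call returns the gradients for the selected `argnums` together with that pair, which is exactly how `accumulate_chunk` unpacks it. A standalone sketch of the pattern with a toy loss (not the package's `_compute_loss`):

```python
import torch

def toy_loss(x, w):
    logits = x @ w.t()
    loss = logits.square().mean()
    aux = logits.mean()            # extra metric carried alongside the loss
    return loss, aux               # (output, aux) because has_aux=True

x = torch.randn(4, 8)
w = torch.randn(16, 8)

# One call yields gradients w.r.t. arguments 0 and 1 plus the (loss, aux) pair
(grad_x, grad_w), (loss, aux) = torch.func.grad_and_value(
    toy_loss, argnums=(0, 1), has_aux=True
)(x, w)
```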
@@ -113,62 +198,61 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                 [chosen_target_chunk, rejected_target_chunk], dim=0
             )
 
-            if compiled:
-                accumulate_chunk = torch.compile(accumulate_chunk)
-            grad_input = accumulate_chunk(input_chunk, target_chunk)
+            # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
+            torch._dynamo.mark_dynamic(input_chunk, 1)
+            torch._dynamo.mark_dynamic(target_chunk, 1)
+            torch._dynamo.mark_dynamic(target, 1)
 
-            grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
-            grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])
+            # accumulate loss, gradients, and metrics
+            accumulate_chunk(input_chunk, target_chunk)
 
         # combine grad_chosen_inputs and grad_rejected_inputs
         grad_inputs = grad_chosen_inputs + grad_rejected_inputs
+        policy_chosen_logps = torch.cat(policy_chosen_logps, dim=0)
+        policy_rejected_logps = torch.cat(policy_rejected_logps, dim=0)
+
+        # Aggregate aux outputs lists into tensors
+        for i, aux in enumerate(aggregated_aux_outputs):
+            if isinstance(aux, list):
+                aggregated_aux_outputs[i] = torch.cat(aux, dim=0)
 
         ctx.save_for_backward(
             torch.cat(grad_inputs, dim=0),
             grad_weight,
             grad_bias,
         )
-        return loss_acc
+        return_vars = (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits_mean,
+            policy_rejected_logits_mean,
+            policy_nll_loss,
+        )
+        return loss_acc, (*return_vars, *aggregated_aux_outputs)
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
-        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
-            grad_input = grad_input * grad_output
-            grad_weight = grad_weight * grad_output
-            grad_bias = grad_bias * grad_output if grad_bias is not None else None
+        if torch.ne(
+            grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)
+        ):
+            grad_input = grad_input * grad_output[0][0]
+            grad_weight = grad_weight * grad_output[0][0]
+            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
 
         return grad_input, grad_weight, None, grad_bias, None, None, None
 
     @staticmethod
-    def _compute_loss(
+    def chunk_forward(
         input_chunk,
         weight,
         target_chunk,
         bias=None,
-        preference_loss_fn=None,
-        full_target=None,
         ignore_index=-100,
-        beta=0.1,
         compute_nll_loss=True,
-        **loss_kwargs,
     ):
-        """
-        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
-        Args:
-            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
-            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
-            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
-            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
-            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
-            ignore_index (int): Index to ignore for loss computation.
-            beta (float): Weight for the odds ratio loss.
-            loss_kwargs (dict): Additional arguments for the loss function.
-        """
         len_chosen_chunk = target_chunk.shape[0] // 2
-
-        logits_chunk = input_chunk @ weight.t()  # chunk_size x V
+        logits_chunk = input_chunk @ weight.t()
         if bias is not None:
             logits_chunk = logits_chunk + bias
         log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
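The `torch._dynamo.mark_dynamic` calls added above tell the compiler to treat dimension 1 (the sequence length) as dynamic, so chunks with different sequence lengths reuse one compiled graph instead of recompiling. A minimal illustration of the same idea, with a made-up function and shapes:

```python
import torch

def seq_sum(x):
    return (x * 2.0).sum()

compiled_seq_sum = torch.compile(seq_sum)

for seq_len in (64, 96, 128):
    x = torch.randn(4, seq_len)
    # Mark dim 1 as dynamic so a new seq_len does not trigger recompilation
    torch._dynamo.mark_dynamic(x, 1)
    compiled_seq_sum(x)
```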
@@ -181,10 +265,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                 reduction="sum",
                 ignore_index=ignore_index,
             )
-            chosen_nll_loss = (
-                chosen_nll_loss
-                / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-            )
 
         loss_mask = target_chunk != ignore_index
         label_chunk = torch.where(loss_mask, target_chunk, 0)
@@ -197,10 +277,110 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         chosen_logps = average_log_prob[:len_chosen_chunk]
         rejected_logps = average_log_prob[len_chosen_chunk:]
 
-        alignment_loss = preference_loss_fn(
-            chosen_logps, rejected_logps, beta=beta, **loss_kwargs
+        chosen_logits = logits_chunk[:len_chosen_chunk]
+        rejected_logits = logits_chunk[len_chosen_chunk:]
+
+        return (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits,
+            rejected_logits,
+            chosen_nll_loss,
+        )
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        alpha=1.0,
+        beta=0.1,
+        compute_nll_loss=True,
+        use_ref_model=False,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            alpha (float): Weight for the NLL loss.
+            beta (float): Weight for the preference loss.
+            compute_nll_loss (bool): Whether to compute NLL loss.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits,
+            rejected_logits,
+            chosen_nll_loss,
+        ) = LigerFusedLinearPreferenceBase.chunk_forward(
+            input_chunk,
+            weight,
+            target_chunk,
+            bias=bias,
+            ignore_index=ignore_index,
+            compute_nll_loss=compute_nll_loss,
+        )
+        chosen_nll_loss = (
+            chosen_nll_loss
+            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        )
+        chosen_logits_mean = chosen_logits.sum() / (
+            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+        )
+        rejected_logits_mean = rejected_logits.sum() / (
+            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+        )
+
+        if use_ref_model:
+            with torch.no_grad():
+                (
+                    ref_chosen_logps,
+                    ref_rejected_logps,
+                    ref_chosen_logits,
+                    ref_rejected_logits,
+                    ref_chosen_nll_loss,
+                ) = LigerFusedLinearPreferenceBase.chunk_forward(
+                    input_chunk,
+                    ref_weight,
+                    target_chunk,
+                    ref_bias,
+                    ignore_index=ignore_index,
+                    compute_nll_loss=False,  # We don't need NLL loss for the reference model
+                )
+            loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
+            loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
+
+        preference_loss_outputs = preference_loss_fn(
+            chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
         )
-        alignment_loss = alignment_loss / (full_target.shape[0] // 2)
+        if isinstance(preference_loss_outputs, tuple):
+            preference_loss, *aux_outputs = preference_loss_outputs
+        else:
+            preference_loss, aux_outputs = preference_loss_outputs, []
 
-        loss = chosen_nll_loss - alignment_loss
-        return loss, (alignment_loss, chosen_logps, rejected_logps)
+        loss = alpha * chosen_nll_loss - preference_loss
+        return_vars = (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits_mean,
+            rejected_logits_mean,
+            chosen_nll_loss,
+        )
+        return loss, (*return_vars, *aux_outputs)
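With this refactor, a subclass only has to supply `preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=..., **kwargs)` returning either a scalar loss or a `(loss, *aux_metrics)` tuple; chunking, the fused forward/backward, and metric aggregation all live in the base class. A hedged sketch of what such a subclass could look like (the toy loss below is illustrative, not part of the package):

```python
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
    LigerFusedLinearPreferenceBase,
)


class ToyPreferenceFunction(LigerFusedLinearPreferenceBase):

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
        # Simple sigmoid preference loss over the average log-probabilities
        margin = beta * (chosen_logps - rejected_logps)
        loss = F.logsigmoid(margin).sum() / (full_target.shape[0] // 2)
        # Scalars returned after the loss are accumulated by the base class as aux metrics
        return loss, margin.mean()
```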
liger_kernel/chunked_loss/orpo_loss.py

@@ -9,12 +9,26 @@ from liger_kernel.chunked_loss.fused_linear_preference import (
 class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
         """
-        Compute odds-ratio loss.
+        Paper: https://arxiv.org/pdf/2403.07691
+
+        Formula:
+        Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x))))
+        where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))
+
+        Where:
+        - P_θ(y|x): Policy (model) probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - σ: Sigmoid function
+        - β: Weight for the odds ratio loss
+        - odds_θ: Odds function for the policy
+
         Args:
             chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            full_target (torch.Tensor): Non chunked full target tensor
             beta (float): Weight for the odds ratio loss.
         """
         log_odds = (chosen_logps - rejected_logps) - (
@@ -22,7 +36,15 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
             - torch.log1p(-torch.exp(rejected_logps))
         )
         ratio = F.logsigmoid(log_odds)
-        return beta * ratio.sum()
+        loss = beta * ratio.sum() / (full_target.shape[0] // 2)
+
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        log_odds_ratio = torch.sum(ratio) / (full_target.shape[0] // 2)
+        log_odds_chosen = torch.sum(log_odds) / (full_target.shape[0] // 2)
+
+        return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen
 
     @staticmethod
     def forward(
@@ -36,12 +58,6 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=True,
         compiled=True,
     ):
-        """
-        Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
-        Handles both the forward and backward pass of the final linear layer with ORPO loss.
-        Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
-        """
-
         return LigerFusedLinearPreferenceBase.forward(
             ctx=ctx,
             _input=_input,
@@ -49,15 +65,49 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
             target=target,
             bias=bias,
             loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
-            compute_nll_loss=compute_nll_loss,
             ignore_index=ignore_index,
             beta=beta,
+            compute_nll_loss=compute_nll_loss,
             compiled=compiled,
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
-        # Get gradients for _input, weight, bias, and target from the base class
+    def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        # Return these gradients, followed by None for the remaining inputs
         return *grads, None, None, None, None
+
+
+class LigerFusedLinearORPOLoss(torch.nn.Module):
+    """
+    Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        compute_nll_loss: bool = True,
+        compiled: bool = True,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss.
+            beta (float): Weight for the odds ratio loss.
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.compute_nll_loss = compute_nll_loss
+        self.compiled = compiled
+
+    def forward(self, lin_weight, _input, target, bias=None):
+        return LigerFusedLinearORPOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            bias,
+            self.ignore_index,
+            self.beta,
+            self.compute_nll_loss,
+            self.compiled,
+        )
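The new `LigerFusedLinearORPOLoss` module wraps the autograd function so the fused loss can be applied directly to the final hidden states and the language-model head weight. A usage sketch, assuming the class is imported from the module path listed above; per the base class's forward, the call should return a loss plus a tuple of aggregated metrics (shapes are illustrative):

```python
import torch

from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss

B, T, H, V = 2, 32, 64, 128                              # illustrative sizes
hidden = torch.randn(2 * B, T, H, requires_grad=True)    # chosen rows stacked before rejected rows
labels = torch.randint(0, V, (2 * B, T))
lm_head_weight = torch.randn(V, H, requires_grad=True)

orpo = LigerFusedLinearORPOLoss(ignore_index=-100, beta=0.1)
loss, metrics = orpo(lm_head_weight, hidden, labels)
loss.backward()
```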
liger_kernel/chunked_loss/simpo_loss.py (new file)

@@ -0,0 +1,115 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import (
+    LigerFusedLinearPreferenceBase,
+)
+
+
+class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
+
+    @staticmethod
+    def preference_loss_fn(
+        chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
+    ):
+        """
+        Paper: https://arxiv.org/pdf/2405.14734
+
+        Formula:
+        L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
+
+        Where:
+        - π_θ(y|x): Policy (model) probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - |y_w|, |y_l|: Sequence lengths
+        - σ: Sigmoid function
+        - β: beta weight
+        - γ: gemma margin term
+
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            full_target: Non chunked full target tensor
+            beta (float): beta weight
+            gamma (float): gemma margin term
+        """
+        logits = beta * (chosen_logps - rejected_logps) - gamma
+        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
+        return loss
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        alpha=1.0,
+        compute_nll_loss=False,
+        compiled=True,
+        gamma=0.5,
+    ):
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx,
+            _input,
+            weight,
+            target,
+            bias,
+            loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
+            compute_nll_loss=compute_nll_loss,
+            ignore_index=ignore_index,
+            alpha=alpha,
+            beta=beta,
+            compiled=compiled,
+            gamma=gamma,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        return *grads, None, None, None, None, None, None
+
+
+class LigerFusedLinearSimPOLoss(torch.nn.Module):
+    """
+    Fused linear layer with SimPO loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        alpha: float = 1.0,
+        compute_nll_loss: bool = True,
+        compiled: bool = True,
+        gamma: float = 0.5,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss.
+            beta (float): Weight for the odds ratio loss.
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.alpha = alpha
+        self.compute_nll_loss = compute_nll_loss
+        self.compiled = compiled
+        self.gamma = gamma
+
+    def forward(self, lin_weight, _input, target, bias=None):
+        return LigerFusedLinearSimPOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            bias,
+            self.ignore_index,
+            self.beta,
+            self.alpha,
+            self.compute_nll_loss,
+            self.compiled,
+            self.gamma,
+        )
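The core of `preference_loss_fn` above is the margin `beta * (chosen_logps - rejected_logps) - gamma` passed through a log-sigmoid. A small standalone illustration of that formula with made-up average log-probabilities (plain torch, not the package API):

```python
import torch
import torch.nn.functional as F

beta, gamma = 0.1, 0.5
chosen_logps = torch.tensor([-0.8, -1.1])    # avg log-prob per chosen sequence
rejected_logps = torch.tensor([-1.5, -1.4])  # avg log-prob per rejected sequence

logits = beta * (chosen_logps - rejected_logps) - gamma
# Same normalization as above: divide by the number of preference pairs
loss = F.logsigmoid(logits).sum() / chosen_logps.shape[0]
```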
liger_kernel/env_report.py

@@ -1,5 +1,6 @@
 import platform
 import sys
+from importlib.metadata import version
 
 
 def print_env_report():
@@ -17,6 +18,11 @@ def print_env_report():
     print(f"Operating System: {platform.platform()}")
     print(f"Python version: {sys.version.split()[0]}")
 
+    try:
+        print(f"Liger Kernel version: {version('liger-kernel')}")
+    except ImportError:
+        print("Liger Kernel: Not installed")
+
     try:
         import torch
 
@@ -25,9 +31,17 @@ def print_env_report():
             torch.version.cuda if torch.cuda.is_available() else "Not available"
         )
         print(f"CUDA version: {cuda_version}")
+        hip_version = (
+            torch.version.hip
+            if torch.cuda.is_available() and torch.version.hip
+            else "Not available"
+        )
+        print(f"HIP(ROCm) version: {hip_version}")
+
     except ImportError:
         print("PyTorch: Not installed")
         print("CUDA version: Unable to query")
+        print("HIP(ROCm) version: Unable to query")
 
     try:
         import triton
@@ -43,6 +57,14 @@ def print_env_report():
     except ImportError:
         print("Transformers: Not installed")
 
+    try:
+        xpu_version = (
+            torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
+        )
+        print(f"XPU version: {xpu_version}")
+    except ImportError:
+        print("XPU version: Unable to query")
+
 
 if __name__ == "__main__":
     print_env_report()
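Given the `if __name__ == "__main__"` guard above, the report can be printed with `python -m liger_kernel.env_report` (assuming liger-kernel is installed in the active environment).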