liger-kernel 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/__init__.py +4 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +107 -0
  4. liger_kernel/chunked_loss/dpo_loss.py +135 -0
  5. liger_kernel/chunked_loss/functional.py +9 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +252 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +386 -0
  8. liger_kernel/chunked_loss/orpo_loss.py +113 -0
  9. liger_kernel/chunked_loss/simpo_loss.py +115 -0
  10. liger_kernel/env_report.py +22 -0
  11. liger_kernel/ops/cross_entropy.py +17 -10
  12. liger_kernel/ops/fused_linear_cross_entropy.py +1 -11
  13. liger_kernel/ops/fused_linear_jsd.py +1 -1
  14. liger_kernel/ops/jsd.py +19 -10
  15. liger_kernel/ops/layer_norm.py +6 -1
  16. liger_kernel/ops/qwen2vl_mrope.py +238 -0
  17. liger_kernel/ops/rms_norm.py +6 -1
  18. liger_kernel/ops/utils.py +5 -2
  19. liger_kernel/transformers/__init__.py +1 -0
  20. liger_kernel/transformers/functional.py +128 -11
  21. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  22. liger_kernel/transformers/jsd.py +1 -4
  23. liger_kernel/transformers/model/qwen2_vl.py +43 -17
  24. liger_kernel/transformers/monkey_patch.py +11 -6
  25. liger_kernel/transformers/orpo_trainer.py +171 -0
  26. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  27. liger_kernel/utils.py +13 -0
  28. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/METADATA +80 -123
  29. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/RECORD +33 -20
  30. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/WHEEL +1 -1
  31. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/LICENSE +0 -0
  32. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/NOTICE +0 -0
  33. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_preference.py
@@ -0,0 +1,386 @@
+ from abc import abstractmethod
+ from functools import partial
+
+ import torch
+ from torch.nn import functional as F
+
+
+ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
+
+     @abstractmethod
+     def preference_loss_fn(*args, **kwargs):
+         """
+         To be extended by subclasses.
+         """
+         raise NotImplementedError("Preference loss function must be implemented.")
+
+     @staticmethod
+     def forward(
+         ctx,
+         _input,
+         weight,
+         target,
+         bias=None,
+         loss_fn=None,
+         chunk_size=1,
+         ignore_index=-100,
+         alpha=1.0,
+         beta=0.1,
+         compute_nll_loss=True,
+         compiled=True,
+         use_ref_model=False,
+         # TODO: ref input
+         ref_weight=None,
+         ref_bias=None,
+         **loss_kwargs,
+     ):
+         """
+         Base class for fused linear layer with preference loss.
+         Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
+
+         The mental model is:
+
+         forward()
+         ├── Loop over chunks
+             └── compute_loss()
+                 ├── chunk_forward()  # Compute logits and log probs
+                 └── prefer_loss()    # Calculate preference loss
+
+         Args:
+             _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
+             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+             target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
+             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+             chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+             ignore_index (int): Index to ignore for loss computation.
+             alpha (float): Weight for the NLL loss.
+             beta (float): Weight for the preference loss.
+             compute_nll_loss (bool): Whether to compute NLL loss.
+             compiled (bool): Whether to use torch compile for chunk accumulation.
+             use_ref_model (bool): Whether to use a reference model for the alignment loss.
+             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+             loss_kwargs (dict): Other possible arguments that a loss function might need
+         """
+         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
+         CHUNK_SIZE = chunk_size
+
+         # Gradients to be accumulated
+         grad_weight = torch.zeros_like(weight)
+         grad_chosen_inputs = []
+         grad_rejected_inputs = []
+         grad_bias = torch.zeros_like(bias) if bias is not None else None
+
+         # Loss to be accumulated
+         loss_acc = torch.zeros((), device=_input.device)
+
+         # Metrics to be recorded
+         policy_chosen_logps = []
+         policy_rejected_logps = []
+         policy_chosen_logits_mean = torch.zeros((), device=_input.device)
+         policy_rejected_logits_mean = torch.zeros((), device=_input.device)
+         policy_nll_loss = torch.zeros((), device=_input.device)
+         aggregated_aux_outputs = []  # aggregated aux outputs from all chunks
+
+         compute_loss = partial(
+             LigerFusedLinearPreferenceBase._compute_loss,
+             preference_loss_fn=loss_fn,
+             ignore_index=ignore_index,
+             alpha=alpha,
+             beta=beta,
+             compute_nll_loss=compute_nll_loss,
+             full_target=target,
+             use_ref_model=use_ref_model,
+             ref_weight=ref_weight,
+             ref_bias=ref_bias,
+             **loss_kwargs,
+         )
+
+         def fused_fwd_bwd(input_chunk, target_chunk):
+             """
+             Fused forward and backward pass for a chunk of input and target.
+             """
+             if bias is not None:
+                 return torch.func.grad_and_value(
+                     compute_loss, argnums=(0, 1, 3), has_aux=True
+                 )(input_chunk, weight, target_chunk, bias)
+             else:
+                 return torch.func.grad_and_value(
+                     compute_loss, argnums=(0, 1), has_aux=True
+                 )(input_chunk, weight, target_chunk)
+
+         def accumulate_chunk(input_chunk, target_chunk):
+             if bias is not None:
+                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
+                     chunk_loss,
+                     (
+                         chunk_chosen_logps,
+                         chunk_rejected_logps,
+                         chunk_chosen_logits_mean,
+                         chunk_rejected_logits_mean,
+                         chunk_nll_loss,
+                         *aux_outputs,
+                     ),
+                 ) = fused_fwd_bwd(input_chunk, target_chunk)
+                 grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
+             else:
+                 (chunk_grad_input, chunk_grad_weight), (
+                     chunk_loss,
+                     (
+                         chunk_chosen_logps,
+                         chunk_rejected_logps,
+                         chunk_chosen_logits_mean,
+                         chunk_rejected_logits_mean,
+                         chunk_nll_loss,
+                         *aux_outputs,
+                     ),
+                 ) = fused_fwd_bwd(input_chunk, target_chunk)
+
+             # Accumulate gradients
+             grad_weight.add_(chunk_grad_weight)
+             grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
+             grad_rejected_inputs.append(
+                 chunk_grad_input[chosen_target_chunk.shape[0] :]
+             )
+
+             # Accumulate loss
+             loss_acc.add_(chunk_loss)
+
+             # Accumulate metrics
+             policy_chosen_logps.append(chunk_chosen_logps)
+             policy_rejected_logps.append(chunk_rejected_logps)
+             policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
+             policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
+             policy_nll_loss.add_(chunk_nll_loss)
+
+             # aux_outputs
+             # Initialize storage for aux_outputs
+             if len(aggregated_aux_outputs) == 0:
+                 for aux in aux_outputs:
+                     if aux.ndim == 0:
+                         aggregated_aux_outputs.append(
+                             torch.zeros((), device=aux.device)
+                         )
+                     else:
+                         aggregated_aux_outputs.append([])
+
+             # Process each aux_output
+             for i, aux in enumerate(aux_outputs):
+                 if aux.ndim == 0:
+                     aggregated_aux_outputs[i].add_(aux)
+                 else:
+                     aggregated_aux_outputs[i].append(aux)
+
+         if compiled:
+             fused_fwd_bwd = torch.compile(fused_fwd_bwd)
+
+         len_chosen = target.shape[0] // 2
+         chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
+         _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
+         _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
+         _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
+         _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
+
+         for (
+             chosen_input_chunk,
+             rejected_input_chunk,
+             chosen_target_chunk,
+             rejected_target_chunk,
+         ) in zip(
+             _chosen_input_chunks,
+             _rejected_input_chunks,
+             _chosen_target_chunks,
+             _rejected_target_chunks,
+         ):
+             input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
+             target_chunk = torch.cat(
+                 [chosen_target_chunk, rejected_target_chunk], dim=0
+             )
+
+             # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
+             torch._dynamo.mark_dynamic(input_chunk, 1)
+             torch._dynamo.mark_dynamic(target_chunk, 1)
+             torch._dynamo.mark_dynamic(target, 1)
+
+             # accumulate loss, gradients, and metrics
+             accumulate_chunk(input_chunk, target_chunk)
+
+         # combine grad_chosen_inputs and grad_rejected_inputs
+         grad_inputs = grad_chosen_inputs + grad_rejected_inputs
+         policy_chosen_logps = torch.cat(policy_chosen_logps, dim=0)
+         policy_rejected_logps = torch.cat(policy_rejected_logps, dim=0)
+
+         # Aggregate aux outputs lists into tensors
+         for i, aux in enumerate(aggregated_aux_outputs):
+             if isinstance(aux, list):
+                 aggregated_aux_outputs[i] = torch.cat(aux, dim=0)
+
+         ctx.save_for_backward(
+             torch.cat(grad_inputs, dim=0),
+             grad_weight,
+             grad_bias,
+         )
+         return_vars = (
+             policy_chosen_logps,
+             policy_rejected_logps,
+             policy_chosen_logits_mean,
+             policy_rejected_logits_mean,
+             policy_nll_loss,
+         )
+         return loss_acc, (*return_vars, *aggregated_aux_outputs)
+
+     @staticmethod
+     def backward(ctx, *grad_output):
+         grad_input, grad_weight, grad_bias = ctx.saved_tensors
+         if torch.ne(
+             grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)
+         ):
+             grad_input = grad_input * grad_output[0][0]
+             grad_weight = grad_weight * grad_output[0][0]
+             grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
+
+         return grad_input, grad_weight, None, grad_bias, None, None, None
+
+     @staticmethod
+     def chunk_forward(
+         input_chunk,
+         weight,
+         target_chunk,
+         bias=None,
+         ignore_index=-100,
+         compute_nll_loss=True,
+     ):
+         len_chosen_chunk = target_chunk.shape[0] // 2
+         logits_chunk = input_chunk @ weight.t()
+         if bias is not None:
+             logits_chunk = logits_chunk + bias
+         log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+         chosen_nll_loss = 0.0
+         if compute_nll_loss:
+             chosen_nll_loss = F.nll_loss(
+                 log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+                 target_chunk[:len_chosen_chunk].view(-1),
+                 reduction="sum",
+                 ignore_index=ignore_index,
+             )
+
+         loss_mask = target_chunk != ignore_index
+         label_chunk = torch.where(loss_mask, target_chunk, 0)
+
+         per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
+             -1
+         )
+         average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+         chosen_logps = average_log_prob[:len_chosen_chunk]
+         rejected_logps = average_log_prob[len_chosen_chunk:]
+
+         chosen_logits = logits_chunk[:len_chosen_chunk]
+         rejected_logits = logits_chunk[len_chosen_chunk:]
+
+         return (
+             chosen_logps,
+             rejected_logps,
+             chosen_logits,
+             rejected_logits,
+             chosen_nll_loss,
+         )
+
+     @staticmethod
+     def _compute_loss(
+         input_chunk,
+         weight,
+         target_chunk,
+         bias=None,
+         preference_loss_fn=None,
+         full_target=None,
+         ignore_index=-100,
+         alpha=1.0,
+         beta=0.1,
+         compute_nll_loss=True,
+         use_ref_model=False,
+         ref_weight=None,
+         ref_bias=None,
+         **loss_kwargs,
+     ):
+         """
+         Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+         Args:
+             preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+             input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+             full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+             ignore_index (int): Index to ignore for loss computation.
+             alpha (float): Weight for the NLL loss.
+             beta (float): Weight for the preference loss.
+             compute_nll_loss (bool): Whether to compute NLL loss.
+             use_ref_model (bool): Whether to use a reference model for the alignment loss.
+             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+             loss_kwargs (dict): Additional arguments for the loss function.
+         """
+         (
+             chosen_logps,
+             rejected_logps,
+             chosen_logits,
+             rejected_logits,
+             chosen_nll_loss,
+         ) = LigerFusedLinearPreferenceBase.chunk_forward(
+             input_chunk,
+             weight,
+             target_chunk,
+             bias=bias,
+             ignore_index=ignore_index,
+             compute_nll_loss=compute_nll_loss,
+         )
+         chosen_nll_loss = (
+             chosen_nll_loss
+             / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+         )
+         chosen_logits_mean = chosen_logits.sum() / (
+             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+         )
+         rejected_logits_mean = rejected_logits.sum() / (
+             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+         )
+
+         if use_ref_model:
+             with torch.no_grad():
+                 (
+                     ref_chosen_logps,
+                     ref_rejected_logps,
+                     ref_chosen_logits,
+                     ref_rejected_logits,
+                     ref_chosen_nll_loss,
+                 ) = LigerFusedLinearPreferenceBase.chunk_forward(
+                     input_chunk,
+                     ref_weight,
+                     target_chunk,
+                     ref_bias,
+                     ignore_index=ignore_index,
+                     compute_nll_loss=False,  # We don't need NLL loss for the reference model
+                 )
+             loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
+             loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
+
+         preference_loss_outputs = preference_loss_fn(
+             chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
+         )
+         if isinstance(preference_loss_outputs, tuple):
+             preference_loss, *aux_outputs = preference_loss_outputs
+         else:
+             preference_loss, aux_outputs = preference_loss_outputs, []
+
+         loss = alpha * chosen_nll_loss - preference_loss
+         return_vars = (
+             chosen_logps,
+             rejected_logps,
+             chosen_logits_mean,
+             rejected_logits_mean,
+             chosen_nll_loss,
+         )
+         return loss, (*return_vars, *aux_outputs)
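
Editor's note: the core of this new base class is the chunked torch.func.grad_and_value loop, which fuses each chunk's forward and backward pass so the full (batch, seq_len, vocab_size) logits tensor is never materialized at once. Below is a minimal sketch of that pattern on a plain cross-entropy head; the toy_loss function and the shapes are illustrative assumptions rather than liger-kernel code, and the real class additionally accumulates metrics and wraps the step in torch.compile.

# Sketch only: toy_loss and the shapes below are illustrative, not liger-kernel code.
import torch
import torch.nn.functional as F


def toy_loss(input_chunk, weight, target_chunk):
    # Project hidden states to vocab logits, then take mean cross-entropy.
    logits = input_chunk @ weight.t()
    return F.cross_entropy(logits.view(-1, logits.size(-1)), target_chunk.view(-1))


B, T, H, V = 4, 8, 16, 32
_input = torch.randn(B, T, H)
weight = torch.randn(V, H)
target = torch.randint(0, V, (B, T))

grad_input_chunks, grad_weight = [], torch.zeros_like(weight)
loss_acc = torch.zeros(())

# Fuse forward + backward per batch chunk so the full (B, T, V) logits tensor
# never exists at once; gradients are accumulated chunk by chunk.
for input_chunk, target_chunk in zip(
    torch.chunk(_input, 2, dim=0), torch.chunk(target, 2, dim=0)
):
    (chunk_grad_input, chunk_grad_weight), chunk_loss = torch.func.grad_and_value(
        toy_loss, argnums=(0, 1)
    )(input_chunk, weight, target_chunk)
    grad_input_chunks.append(chunk_grad_input)
    grad_weight.add_(chunk_grad_weight)
    loss_acc.add_(chunk_loss / 2)  # average of the two chunk losses

grad_input = torch.cat(grad_input_chunks, dim=0)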
liger_kernel/chunked_loss/orpo_loss.py
@@ -0,0 +1,113 @@
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_preference import (
+     LigerFusedLinearPreferenceBase,
+ )
+
+
+ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
+
+     @staticmethod
+     def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+         """
+         Paper: https://arxiv.org/pdf/2403.07691
+
+         Formula:
+         Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x))))
+         where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))
+
+         Where:
+             - P_θ(y|x): Policy (model) probability
+             - y_w: Chosen sequence
+             - y_l: Rejected sequence
+             - σ: Sigmoid function
+             - β: Weight for the odds ratio loss
+             - odds_θ: Odds function for the policy
+
+         Args:
+             chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+             full_target (torch.Tensor): Non chunked full target tensor
+             beta (float): Weight for the odds ratio loss.
+         """
+         log_odds = (chosen_logps - rejected_logps) - (
+             torch.log1p(-torch.exp(chosen_logps))
+             - torch.log1p(-torch.exp(rejected_logps))
+         )
+         ratio = F.logsigmoid(log_odds)
+         loss = beta * ratio.sum() / (full_target.shape[0] // 2)
+
+         chosen_rewards = beta * chosen_logps
+         rejected_rewards = beta * rejected_logps
+
+         log_odds_ratio = torch.sum(ratio) / (full_target.shape[0] // 2)
+         log_odds_chosen = torch.sum(log_odds) / (full_target.shape[0] // 2)
+
+         return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen
+
+     @staticmethod
+     def forward(
+         ctx,
+         _input,
+         weight,
+         target,
+         bias=None,
+         ignore_index=-100,
+         beta=0.1,
+         compute_nll_loss=True,
+         compiled=True,
+     ):
+         return LigerFusedLinearPreferenceBase.forward(
+             ctx=ctx,
+             _input=_input,
+             weight=weight,
+             target=target,
+             bias=bias,
+             loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
+             ignore_index=ignore_index,
+             beta=beta,
+             compute_nll_loss=compute_nll_loss,
+             compiled=compiled,
+         )
+
+     @staticmethod
+     def backward(ctx, *grad_output):
+         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+         return *grads, None, None, None, None
+
+
+ class LigerFusedLinearORPOLoss(torch.nn.Module):
+     """
+     Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
+     """
+
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         beta: float = 0.1,
+         compute_nll_loss: bool = True,
+         compiled: bool = True,
+     ):
+         """
+         Args:
+             ignore_index (int): Index to ignore in the loss.
+             beta (float): Weight for the odds ratio loss.
+         """
+         super().__init__()
+         self.ignore_index = ignore_index
+         self.beta = beta
+         self.compute_nll_loss = compute_nll_loss
+         self.compiled = compiled
+
+     def forward(self, lin_weight, _input, target, bias=None):
+         return LigerFusedLinearORPOFunction.apply(
+             _input,
+             lin_weight,
+             target,
+             bias,
+             self.ignore_index,
+             self.beta,
+             self.compute_nll_loss,
+             self.compiled,
+         )
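
Editor's note: a hedged usage sketch for the new ORPO module; the tensor shapes and variable names below are illustrative assumptions, not liger-kernel test code. The first half of the batch holds the chosen responses and the second half the rejected ones, and the module takes the lm_head weight plus last hidden states rather than pre-computed logits.

# Sketch only: shapes and names are illustrative assumptions.
import torch
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss

B, T, H, V = 4, 16, 64, 128                        # B = 2 * number of preference pairs
hidden = torch.randn(B, T, H, requires_grad=True)  # stacked [chosen; rejected] hidden states
lm_head_weight = torch.randn(V, H, requires_grad=True)
target = torch.randint(0, V, (B, T))
target[:, :4] = -100                               # mask prompt tokens via ignore_index

orpo = LigerFusedLinearORPOLoss(beta=0.1)
loss, aux = orpo(lm_head_weight, hidden, target)   # aux: logps, logits means, NLL, rewards, ...
loss.backward()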
liger_kernel/chunked_loss/simpo_loss.py
@@ -0,0 +1,115 @@
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_preference import (
+     LigerFusedLinearPreferenceBase,
+ )
+
+
+ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
+
+     @staticmethod
+     def preference_loss_fn(
+         chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
+     ):
+         """
+         Paper: https://arxiv.org/pdf/2405.14734
+
+         Formula:
+         L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
+
+         Where:
+             - π_θ(y|x): Policy (model) probability
+             - y_w: Chosen sequence
+             - y_l: Rejected sequence
+             - |y_w|, |y_l|: Sequence lengths
+             - σ: Sigmoid function
+             - β: beta weight
+             - γ: gamma margin term
+
+         Args:
+             chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+             full_target: Non chunked full target tensor
+             beta (float): beta weight
+             gamma (float): gamma margin term
+         """
+         logits = beta * (chosen_logps - rejected_logps) - gamma
+         loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
+         return loss
+
+     @staticmethod
+     def forward(
+         ctx,
+         _input,
+         weight,
+         target,
+         bias=None,
+         ignore_index=-100,
+         beta=0.1,
+         alpha=1.0,
+         compute_nll_loss=False,
+         compiled=True,
+         gamma=0.5,
+     ):
+         return LigerFusedLinearPreferenceBase.forward(
+             ctx,
+             _input,
+             weight,
+             target,
+             bias,
+             loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
+             compute_nll_loss=compute_nll_loss,
+             ignore_index=ignore_index,
+             alpha=alpha,
+             beta=beta,
+             compiled=compiled,
+             gamma=gamma,
+         )
+
+     @staticmethod
+     def backward(ctx, *grad_output):
+         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+         return *grads, None, None, None, None, None, None
+
+
+ class LigerFusedLinearSimPOLoss(torch.nn.Module):
+     """
+     Fused linear layer with SimPO loss.
+     """
+
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         beta: float = 0.1,
+         alpha: float = 1.0,
+         compute_nll_loss: bool = True,
+         compiled: bool = True,
+         gamma: float = 0.5,
+     ):
+         """
+         Args:
+             ignore_index (int): Index to ignore in the loss.
+             beta (float): Weight for the odds ratio loss.
+         """
+         super().__init__()
+         self.ignore_index = ignore_index
+         self.beta = beta
+         self.alpha = alpha
+         self.compute_nll_loss = compute_nll_loss
+         self.compiled = compiled
+         self.gamma = gamma
+
+     def forward(self, lin_weight, _input, target, bias=None):
+         return LigerFusedLinearSimPOFunction.apply(
+             _input,
+             lin_weight,
+             target,
+             bias,
+             self.ignore_index,
+             self.beta,
+             self.alpha,
+             self.compute_nll_loss,
+             self.compiled,
+             self.gamma,
+         )
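
Editor's note: a matching usage sketch for the SimPO module, again with illustrative shapes and names that are assumptions rather than liger-kernel test code. SimPO adds the gamma margin on top of the beta-scaled log-probability difference, and this sketch disables the auxiliary NLL term.

# Sketch only: shapes and names are illustrative assumptions.
import torch
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss

B, T, H, V = 4, 16, 64, 128
hidden = torch.randn(B, T, H, requires_grad=True)   # [chosen; rejected] stacked on the batch dim
lm_head_weight = torch.randn(V, H, requires_grad=True)
target = torch.randint(0, V, (B, T))

simpo = LigerFusedLinearSimPOLoss(beta=2.0, gamma=0.5, compute_nll_loss=False)
loss, aux = simpo(lm_head_weight, hidden, target)
loss.backward()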
liger_kernel/env_report.py
@@ -1,5 +1,6 @@
  import platform
  import sys
+ from importlib.metadata import version


  def print_env_report():
@@ -17,6 +18,11 @@ def print_env_report():
      print(f"Operating System: {platform.platform()}")
      print(f"Python version: {sys.version.split()[0]}")

+     try:
+         print(f"Liger Kernel version: {version('liger-kernel')}")
+     except ImportError:
+         print("Liger Kernel: Not installed")
+
      try:
          import torch

@@ -25,9 +31,17 @@ def print_env_report():
              torch.version.cuda if torch.cuda.is_available() else "Not available"
          )
          print(f"CUDA version: {cuda_version}")
+         hip_version = (
+             torch.version.hip
+             if torch.cuda.is_available() and torch.version.hip
+             else "Not available"
+         )
+         print(f"HIP(ROCm) version: {hip_version}")
+
      except ImportError:
          print("PyTorch: Not installed")
          print("CUDA version: Unable to query")
+         print("HIP(ROCm) version: Unable to query")

      try:
          import triton
@@ -43,6 +57,14 @@ def print_env_report():
      except ImportError:
          print("Transformers: Not installed")

+     try:
+         xpu_version = (
+             torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
+         )
+         print(f"XPU version: {xpu_version}")
+     except ImportError:
+         print("XPU version: Unable to query")
+

  if __name__ == "__main__":
      print_env_report()
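
Editor's note: the updated report can be invoked directly (or via `python -m liger_kernel.env_report`, since the module guards the call under __main__); the snippet below simply calls the function shown above.

# Prints the environment report, which now also includes the Liger Kernel,
# HIP(ROCm), and XPU versions added in this release.
from liger_kernel.env_report import print_env_report

print_env_report()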