liger-kernel 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (33)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/__init__.py +4 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +107 -0
  4. liger_kernel/chunked_loss/dpo_loss.py +135 -0
  5. liger_kernel/chunked_loss/functional.py +9 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +252 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +386 -0
  8. liger_kernel/chunked_loss/orpo_loss.py +113 -0
  9. liger_kernel/chunked_loss/simpo_loss.py +115 -0
  10. liger_kernel/env_report.py +22 -0
  11. liger_kernel/ops/cross_entropy.py +17 -10
  12. liger_kernel/ops/fused_linear_cross_entropy.py +1 -11
  13. liger_kernel/ops/fused_linear_jsd.py +1 -1
  14. liger_kernel/ops/jsd.py +19 -10
  15. liger_kernel/ops/layer_norm.py +6 -1
  16. liger_kernel/ops/qwen2vl_mrope.py +238 -0
  17. liger_kernel/ops/rms_norm.py +6 -1
  18. liger_kernel/ops/utils.py +5 -2
  19. liger_kernel/transformers/__init__.py +1 -0
  20. liger_kernel/transformers/functional.py +128 -11
  21. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  22. liger_kernel/transformers/jsd.py +1 -4
  23. liger_kernel/transformers/model/qwen2_vl.py +43 -17
  24. liger_kernel/transformers/monkey_patch.py +11 -6
  25. liger_kernel/transformers/orpo_trainer.py +171 -0
  26. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  27. liger_kernel/utils.py +13 -0
  28. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/METADATA +80 -123
  29. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/RECORD +33 -20
  30. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/WHEEL +1 -1
  31. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/LICENSE +0 -0
  32. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/NOTICE +0 -0
  33. {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/__init__.py
@@ -0,0 +1,4 @@
+ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
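The new chunked_loss package re-exports the four loss modules, so they can be imported directly from liger_kernel.chunked_loss; a minimal import check, assuming liger-kernel 0.5.0 is installed:

    from liger_kernel.chunked_loss import (
        LigerFusedLinearCPOLoss,
        LigerFusedLinearDPOLoss,
        LigerFusedLinearORPOLoss,
        LigerFusedLinearSimPOLoss,
    )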
liger_kernel/chunked_loss/cpo_loss.py
@@ -0,0 +1,107 @@
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_preference import (
+     LigerFusedLinearPreferenceBase,
+ )
+
+
+ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
+
+     @staticmethod
+     def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+         """
+         Paper: https://arxiv.org/pdf/2401.08417
+
+         Formula:
+         L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]
+
+         Where:
+         - π_θ(y|x): Policy (model) probability
+         - y_w: Chosen sequence
+         - y_l: Rejected sequence
+         - σ: Sigmoid function
+         - β: Temperature parameter
+         - E: Expected value over the dataset D
+         - D: Dataset of preferences
+
+         Args:
+             chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+             full_target (torch.Tensor): Non chunked full target tensor
+             beta (float): Weight for the CPO loss
+         """
+         logits = beta * (chosen_logps - rejected_logps)
+         loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
+         return loss
+
+     @staticmethod
+     def forward(
+         ctx,
+         _input,
+         weight,
+         target,
+         bias=None,
+         ignore_index=-100,
+         beta=0.1,
+         alpha=1.0,
+         compute_nll_loss=True,
+         compiled=True,
+     ):
+         return LigerFusedLinearPreferenceBase.forward(
+             ctx,
+             _input,
+             weight,
+             target,
+             bias,
+             loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
+             ignore_index=ignore_index,
+             alpha=alpha,
+             beta=beta,
+             compute_nll_loss=compute_nll_loss,
+             compiled=compiled,
+         )
+
+     @staticmethod
+     def backward(ctx, *grad_output):
+         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+         return *grads, None, None, None, None, None
+
+
+ class LigerFusedLinearCPOLoss(torch.nn.Module):
+     """
+     Fused linear layer with CPO loss.
+     """
+
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         beta: float = 0.1,
+         alpha: float = 1.0,
+         compute_nll_loss: bool = True,
+         compiled: bool = True,
+     ):
+         """
+         Args:
+             ignore_index (int): Index to ignore in the loss.
+             beta (float): Weight for the odds ratio loss.
+         """
+         super().__init__()
+         self.ignore_index = ignore_index
+         self.beta = beta
+         self.alpha = alpha
+         self.compute_nll_loss = compute_nll_loss
+         self.compiled = compiled
+
+     def forward(self, lin_weight, _input, target, bias=None):
+         return LigerFusedLinearCPOFunction.apply(
+             _input,
+             lin_weight,
+             target,
+             bias,
+             self.ignore_index,
+             self.beta,
+             self.alpha,
+             self.compute_nll_loss,
+             self.compiled,
+         )
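A rough usage sketch for the new module (not part of the diff; sizes and shapes are illustrative and follow the docstrings above, with chosen and rejected sequences stacked along the batch dimension so the batch size is even):

    import torch
    from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss

    B, T, H, V = 4, 128, 256, 32000    # illustrative sizes; first B//2 rows chosen, last B//2 rejected
    hidden = torch.randn(B, T, H, requires_grad=True)       # pre-lm_head hidden states
    lm_head_weight = torch.randn(V, H, requires_grad=True)
    target = torch.randint(0, V, (B, T))

    cpo = LigerFusedLinearCPOLoss(ignore_index=-100, beta=0.1)
    out = cpo(lm_head_weight, hidden, target)   # forward(lin_weight, _input, target, bias=None)
    # Depending on the base class, the return may be the scalar loss alone or a
    # (loss, aux_outputs) tuple; take the scalar before calling backward().
    loss = out[0] if isinstance(out, tuple) else out
    loss.backward()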
liger_kernel/chunked_loss/dpo_loss.py
@@ -0,0 +1,135 @@
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_preference import (
+     LigerFusedLinearPreferenceBase,
+ )
+
+
+ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
+
+     @staticmethod
+     def preference_loss_fn(
+         chosen_logps,
+         rejected_logps,
+         full_target,
+         ref_chosen_logps=None,
+         ref_rejected_logps=None,
+         beta=0.1,
+     ):
+         """
+         Paper: https://arxiv.org/pdf/2305.18290
+
+         Formula:
+         L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ]
+
+         Where:
+         - π(y|x): Policy (model) probability
+         - π_ref(y|x): Reference model probability
+         - y_w: Chosen sequence
+         - y_l: Rejected sequence
+         - β: Weight for the direct preference loss
+         - E: Expected value over the dataset
+
+         Args:
+             chosen_logps: Log probabilities of chosen tokens (batch_size,)
+             rejected_logps: Log probabilities of rejected tokens (batch_size,)
+             full_target: Non chunked full target tensor
+             ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
+             ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
+             beta: Weight for the direct preference loss
+         """
+
+         if ref_chosen_logps is None:
+             ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
+         if ref_rejected_logps is None:
+             ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
+
+         chosen_logratios = chosen_logps - ref_chosen_logps
+         rejected_logratios = rejected_logps - ref_rejected_logps
+
+         logits_diff = beta * (chosen_logratios - rejected_logratios)
+         loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+         return loss
+
+     @staticmethod
+     def forward(
+         ctx,
+         _input,
+         weight,
+         target,
+         bias=None,
+         ref_weight=None,
+         ref_bias=None,
+         ignore_index=-100,
+         beta=0.1,
+         compute_nll_loss=True,
+         compiled=True,
+         use_ref_model=True,
+     ):
+         return LigerFusedLinearPreferenceBase.forward(
+             ctx=ctx,
+             _input=_input,
+             weight=weight,
+             target=target,
+             bias=bias,
+             loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
+             ignore_index=ignore_index,
+             beta=beta,
+             compute_nll_loss=compute_nll_loss,
+             compiled=compiled,
+             use_ref_model=use_ref_model,
+             ref_weight=ref_weight,
+             ref_bias=ref_bias,
+         )
+
+     @staticmethod
+     def backward(ctx, *grad_output):
+         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+         return *grads, None, None, None, None, None, None, None
+
+
+ class LigerFusedLinearDPOLoss(torch.nn.Module):
+     """
+     Fused linear layer with DPO loss.
+     """
+
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         beta: float = 0.1,
+         compute_nll_loss: bool = True,
+         compiled: bool = True,
+         use_ref_model: bool = False,
+     ):
+         """
+         Args:
+             ignore_index (int): Index to ignore in the loss.
+             beta (float): Weight for the odds ratio loss.
+             compute_nll_loss (bool): Whether to compute the NLL loss.
+             compiled (bool): Whether to use the torch compiled kernel.
+             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+         """
+         super().__init__()
+         self.ignore_index = ignore_index
+         self.beta = beta
+         self.compute_nll_loss = compute_nll_loss
+         self.compiled = compiled
+         self.use_ref_model = use_ref_model
+
+     def forward(
+         self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
+     ):
+         return LigerFusedLinearDPOFunction.apply(
+             _input,
+             lin_weight,
+             target,
+             bias,
+             ref_weight,
+             ref_bias,
+             self.ignore_index,
+             self.beta,
+             self.compute_nll_loss,
+             self.compiled,
+             self.use_ref_model,
+         )
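A similar sketch for the DPO module (illustrative; the reference weight stands in for a frozen reference model's lm_head, and the same hidden states are reused for policy and reference, as the forward signature above suggests):

    import torch
    from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

    B, T, H, V = 4, 128, 256, 32000     # illustrative sizes
    hidden = torch.randn(B, T, H, requires_grad=True)
    target = torch.randint(0, V, (B, T))
    policy_head = torch.randn(V, H, requires_grad=True)
    ref_head = torch.randn(V, H)        # frozen reference lm_head weight

    dpo = LigerFusedLinearDPOLoss(beta=0.1, use_ref_model=True)
    out = dpo(policy_head, hidden, target, ref_weight=ref_head)
    loss = out[0] if isinstance(out, tuple) else out
    loss.backward()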
liger_kernel/chunked_loss/functional.py
@@ -0,0 +1,9 @@
+ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
+ from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+ from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
+ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
+
+ liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
+ liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
+ liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
+ liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
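These aliases expose the raw autograd Function entry points, so arguments are positional and mirror each Function's forward signature without ctx. A sketch for the CPO alias (sizes illustrative):

    import torch
    from liger_kernel.chunked_loss.functional import liger_fused_linear_cpo

    # Positional order from LigerFusedLinearCPOFunction.forward above:
    # (_input, weight, target, bias, ignore_index, beta, alpha, compute_nll_loss, compiled)
    B, T, H, V = 2, 16, 64, 1000
    _input = torch.randn(B, T, H, requires_grad=True)
    weight = torch.randn(V, H, requires_grad=True)
    target = torch.randint(0, V, (B, T))

    out = liger_fused_linear_cpo(_input, weight, target, None, -100, 0.1, 1.0, True, True)
    loss = out[0] if isinstance(out, tuple) else out
    loss.backward()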
liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -0,0 +1,252 @@
+ from abc import abstractmethod
+ from functools import partial
+
+ import torch
+ from torch.nn import functional as F
+
+
+ class LigerFusedLinearDistillationBase(torch.autograd.Function):
+
+     @abstractmethod
+     def distillation_loss_fn(student_logits, teacher_logits, temperature):
+         """
+         Compute distillation loss.
+         Args:
+             student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+             teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+         """
+         raise NotImplementedError("Distillation loss function must be implemented.")
+
+     @staticmethod
+     def chunk_forward(
+         student_input_chunk,
+         student_weight,
+         teacher_input_chunk,
+         teacher_weight,
+         target_chunk,
+         student_bias=None,
+         teacher_bias=None,
+         ignore_index=-100,
+         compute_ce_loss=True,
+     ):
+         # Student
+         student_logits_chunk = student_input_chunk @ student_weight.t()
+         if student_bias is not None:
+             student_logits_chunk += student_bias
+         student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)
+
+         # Teacher
+         with torch.no_grad():
+             teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
+             if teacher_bias is not None:
+                 teacher_logits_chunk += teacher_bias
+
+         # The hard/task loss
+         ce_loss = 0.0
+         if compute_ce_loss:
+             ce_loss = F.nll_loss(
+                 student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
+                 target_chunk.view(-1),
+                 reduction="sum",
+                 ignore_index=ignore_index,
+             )
+
+         return student_logits_chunk, teacher_logits_chunk, ce_loss
+
+     @staticmethod
+     def _compute_loss(
+         student_input_chunk,
+         student_weight,
+         teacher_input_chunk,
+         teacher_weight,
+         target_chunk,
+         student_bias=None,
+         teacher_bias=None,
+         distillation_loss_fn=None,
+         full_target=None,
+         ignore_index=-100,
+         temperature=1.0,
+         weight_hard_loss=0.5,
+         weight_soft_loss=0.5,
+         compute_ce_loss=True,
+         **loss_kwargs,
+     ):
+         """
+         Compute the total loss for a chunk of input and target, using a knowledge distillation loss function.
+         Args:
+             distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+             student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
+             student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
+             teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
+             teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
+             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
+             student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+             teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+             full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,).
+             ignore_index (int): Index to ignore for loss computation.
+             weight_hard_loss (float): Weight for hard loss.
+             weight_soft_loss (float): Weight for soft loss.
+             compute_ce_loss (bool): Whether to compute CE loss.
+             loss_kwargs (dict): Additional arguments for the loss function.
+         """
+         student_logits_chunk, teacher_logits_chunk, hard_loss = (
+             LigerFusedLinearDistillationBase.chunk_forward(
+                 student_input_chunk,
+                 student_weight,
+                 teacher_input_chunk,
+                 teacher_weight,
+                 target_chunk,
+                 student_bias=student_bias,
+                 teacher_bias=teacher_bias,
+                 ignore_index=ignore_index,
+                 compute_ce_loss=compute_ce_loss,
+             )
+         )
+
+         hard_loss /= full_target.shape[0]
+
+         soft_loss = distillation_loss_fn(
+             student_logits_chunk, teacher_logits_chunk, temperature
+         )
+         soft_loss /= full_target.shape[0]
+
+         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
+         return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
+
+     @staticmethod
+     def forward(
+         ctx,
+         student_input,
+         student_weight,
+         teacher_input,
+         teacher_weight,
+         target,
+         student_bias=None,
+         teacher_bias=None,
+         loss_fn=None,
+         chunk_size=1024,
+         ignore_index=-100,
+         weight_hard_loss=0.5,
+         weight_soft_loss=0.5,
+         compute_ce_loss=True,
+         temperature=1.0,
+         compiled=True,
+         **loss_kwargs,
+     ):
+         """
+         Base class for fused linear layer with distillation loss.
+         Only need to compute gradients for student model.
+
+         Args:
+             student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size).
+             student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size).
+             teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size).
+             teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size).
+             target (torch.Tensor): Target truth label tensor. Shape: (batch_size * seq_len).
+             student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,).
+             teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
+             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+             chunk_size (int): Size of a chunk.
+             compute_ce_loss (bool): Whether to compute CE loss.
+             ignore_index (int): Index to ignore for loss computation.
+             weight_hard_loss (float): Weight for hard/task loss.
+             weight_soft_loss (float): Weight for soft/distillation loss.
+             compiled (bool): Whether to use torch compile for chunk accumulation.
+             loss_kwargs (dict): Other possible arguments that a loss function might need
+         """
+         CHUNK_SIZE = chunk_size
+         grad_weight = torch.zeros_like(student_weight)
+         grad_inputs = []
+         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
+         loss_acc = torch.zeros((), device=student_input.device)
+
+         loss_func_to_call = partial(
+             LigerFusedLinearDistillationBase._compute_loss,
+             distillation_loss_fn=loss_fn,
+             full_target=target,
+             ignore_index=ignore_index,
+             weight_hard_loss=weight_hard_loss,
+             weight_soft_loss=weight_soft_loss,
+             compute_ce_loss=compute_ce_loss,
+             temperature=temperature,
+             **loss_kwargs,
+         )
+
+         def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
+             if student_bias is not None:
+                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
+                     chunk_loss,
+                     (
+                         chunk_soft_loss,
+                         chunk_hard_loss,
+                         chunk_student_logits,
+                         chunk_teacher_logits,
+                     ),
+                 ) = torch.func.grad_and_value(
+                     loss_func_to_call, argnums=(0, 1, 5), has_aux=True
+                 )(
+                     student_input_chunk,
+                     student_weight,
+                     teacher_input_chunk,
+                     teacher_weight,
+                     target_chunk,
+                     student_bias,
+                     teacher_bias,
+                 )
+                 grad_bias.add_(chunk_grad_bias)
+             else:
+                 (chunk_grad_input, chunk_grad_weight), (
+                     chunk_loss,
+                     (
+                         chunk_soft_loss,
+                         chunk_hard_loss,
+                         chunk_student_logits,
+                         chunk_teacher_logits,
+                     ),
+                 ) = torch.func.grad_and_value(
+                     loss_func_to_call, argnums=(0, 1), has_aux=True
+                 )(
+                     student_input_chunk,
+                     student_weight,
+                     teacher_input_chunk,
+                     teacher_weight,
+                     target_chunk,
+                     student_bias,
+                     teacher_bias,
+                 )
+             grad_weight.add_(chunk_grad_weight)
+             loss_acc.add_(chunk_loss)
+             return chunk_grad_input
+
+         if compiled:
+             accumulate_chunk = torch.compile(accumulate_chunk)
+
+         num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
+         _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
+         _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
+         _target_chunks = torch.chunk(target, chunks=num_chunks, dim=0)
+
+         for student_input_chunk, teacher_input_chunk, target_chunk in zip(
+             _student_input_chunks, _teacher_input_chunks, _target_chunks
+         ):
+             grad_input = accumulate_chunk(
+                 student_input_chunk, teacher_input_chunk, target_chunk
+             )
+             grad_inputs.append(grad_input)
+
+         ctx.save_for_backward(
+             torch.cat(grad_inputs, dim=0),
+             grad_weight,
+             grad_bias,
+         )
+         return loss_acc
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         grad_input, grad_weight, grad_bias = ctx.saved_tensors
+         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+             grad_input = grad_input * grad_output
+             grad_weight = grad_weight * grad_output
+             grad_bias = grad_bias * grad_output if grad_bias is not None else None
+
+         return grad_input, grad_weight, None, grad_bias
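LigerFusedLinearDistillationBase is an abstract pipeline: a subclass supplies distillation_loss_fn(student_logits, teacher_logits, temperature) plus thin forward/backward wrappers, while the base handles the chunked projections, the hard CE term, and per-chunk gradient accumulation via torch.func.grad_and_value. A hypothetical subclass sketch with a sum-reduced KL soft loss (not part of this release; the class name and loss choice are illustrative):

    import torch
    import torch.nn.functional as F

    from liger_kernel.chunked_loss.fused_linear_distillation import (
        LigerFusedLinearDistillationBase,
    )


    class NaiveKLDistillationFunction(LigerFusedLinearDistillationBase):
        """Hypothetical subclass: KL(teacher || student) on temperature-scaled logits."""

        @staticmethod
        def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0):
            # The base class divides the returned value by full_target.shape[0],
            # so a "sum" reduction here matches the hard-loss normalization above.
            student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
            teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
            return F.kl_div(student_log_probs, teacher_probs, reduction="sum")

        @staticmethod
        def forward(ctx, student_input, student_weight, teacher_input, teacher_weight, target):
            return LigerFusedLinearDistillationBase.forward(
                ctx,
                student_input,
                student_weight,
                teacher_input,
                teacher_weight,
                target,
                loss_fn=NaiveKLDistillationFunction.distillation_loss_fn,
            )

        @staticmethod
        def backward(ctx, grad_output):
            grad_input, grad_weight, _, _ = LigerFusedLinearDistillationBase.backward(
                ctx, grad_output
            )
            # One gradient slot per forward() argument after ctx:
            # (student_input, student_weight, teacher_input, teacher_weight, target)
            return grad_input, grad_weight, None, None, None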