liger-kernel 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cpo_loss.py +51 -11
- liger_kernel/chunked_loss/dpo_loss.py +30 -4
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +20 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
- liger_kernel/chunked_loss/grpo_loss.py +137 -61
- liger_kernel/chunked_loss/jsd_loss.py +43 -13
- liger_kernel/chunked_loss/kto_loss.py +50 -12
- liger_kernel/chunked_loss/orpo_loss.py +37 -5
- liger_kernel/chunked_loss/simpo_loss.py +47 -11
- liger_kernel/ops/cross_entropy.py +7 -2
- liger_kernel/ops/dyt.py +225 -0
- liger_kernel/ops/fused_linear_jsd.py +2 -1
- liger_kernel/ops/jsd.py +30 -11
- liger_kernel/ops/kl_div.py +2 -2
- liger_kernel/transformers/__init__.py +4 -0
- liger_kernel/transformers/dyt.py +20 -0
- liger_kernel/transformers/functional.py +5 -0
- liger_kernel/transformers/model/gemma.py +8 -16
- liger_kernel/transformers/model/gemma2.py +7 -16
- liger_kernel/transformers/model/llama.py +8 -15
- liger_kernel/transformers/model/llava.py +369 -0
- liger_kernel/transformers/model/loss_utils.py +57 -0
- liger_kernel/transformers/model/mistral.py +9 -10
- liger_kernel/transformers/model/mixtral.py +8 -15
- liger_kernel/transformers/model/mllama.py +8 -15
- liger_kernel/transformers/model/olmo2.py +8 -16
- liger_kernel/transformers/model/paligemma.py +397 -0
- liger_kernel/transformers/model/phi3.py +8 -15
- liger_kernel/transformers/model/qwen2.py +8 -15
- liger_kernel/transformers/model/qwen2_5_vl.py +204 -0
- liger_kernel/transformers/model/qwen2_vl.py +9 -10
- liger_kernel/transformers/monkey_patch.py +286 -12
- liger_kernel/utils.py +1 -3
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +11 -7
- liger_kernel-0.5.6.dist-info/RECORD +80 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
- liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -213
- liger_kernel-0.5.4.dist-info/RECORD +0 -74
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
|
@@ -39,8 +39,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
39
39
|
|
|
40
40
|
return loss, chosen_rewards, rejected_rewards
|
|
41
41
|
|
|
42
|
-
@
|
|
42
|
+
@classmethod
|
|
43
43
|
def forward(
|
|
44
|
+
cls,
|
|
44
45
|
ctx,
|
|
45
46
|
_input,
|
|
46
47
|
weight,
|
|
@@ -52,27 +53,48 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
52
53
|
label_smoothing=0.0,
|
|
53
54
|
compute_nll_loss=True,
|
|
54
55
|
compiled=True,
|
|
56
|
+
average_log_prob=False,
|
|
57
|
+
chunk_size=1,
|
|
55
58
|
):
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
59
|
+
"""
|
|
60
|
+
Fused linear layer with CPO loss.
|
|
61
|
+
Args:
|
|
62
|
+
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
63
|
+
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
|
|
64
|
+
target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
|
|
65
|
+
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
|
|
66
|
+
ignore_index (int): Index to ignore in loss computation
|
|
67
|
+
beta (float): Weight for the odds ratio loss
|
|
68
|
+
alpha (float): Weight for the alpha parameter
|
|
69
|
+
label_smoothing (float): Label smoothing factor
|
|
70
|
+
compute_nll_loss (bool): Whether to compute the NLL loss
|
|
71
|
+
compiled (bool): Whether to use torch compile
|
|
72
|
+
average_log_prob (bool): Whether to average the log probability per non-masked token
|
|
73
|
+
chunk_size (int): Size of chunks for processing.
|
|
74
|
+
Returns:
|
|
75
|
+
torch.Tensor: Computed loss
|
|
76
|
+
"""
|
|
77
|
+
return super().forward(
|
|
78
|
+
cls=cls,
|
|
79
|
+
ctx=ctx,
|
|
80
|
+
_input=_input,
|
|
81
|
+
weight=weight,
|
|
82
|
+
target=target,
|
|
83
|
+
bias=bias,
|
|
63
84
|
ignore_index=ignore_index,
|
|
64
85
|
alpha=alpha,
|
|
65
86
|
beta=beta,
|
|
66
87
|
label_smoothing=label_smoothing,
|
|
67
88
|
compute_nll_loss=compute_nll_loss,
|
|
68
|
-
average_log_prob=
|
|
89
|
+
average_log_prob=average_log_prob,
|
|
69
90
|
compiled=compiled,
|
|
91
|
+
chunk_size=chunk_size,
|
|
70
92
|
)
|
|
71
93
|
|
|
72
94
|
@staticmethod
|
|
73
95
|
def backward(ctx, *grad_output):
|
|
74
96
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
|
75
|
-
return *grads, None, None, None, None, None, None
|
|
97
|
+
return *grads, None, None, None, None, None, None, None, None
|
|
76
98
|
|
|
77
99
|
|
|
78
100
|
class LigerFusedLinearCPOLoss(torch.nn.Module):
|
|
@@ -88,11 +110,19 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
|
|
|
88
110
|
label_smoothing: float = 0.0,
|
|
89
111
|
compute_nll_loss: bool = True,
|
|
90
112
|
compiled: bool = True,
|
|
113
|
+
average_log_prob: bool = False,
|
|
114
|
+
chunk_size: int = 1,
|
|
91
115
|
):
|
|
92
116
|
"""
|
|
93
117
|
Args:
|
|
94
118
|
ignore_index (int): Index to ignore in the loss.
|
|
95
119
|
beta (float): Weight for the odds ratio loss.
|
|
120
|
+
alpha (float): Weight for the alpha parameter.
|
|
121
|
+
label_smoothing (float): Label smoothing factor.
|
|
122
|
+
compute_nll_loss (bool): Whether to compute the NLL loss.
|
|
123
|
+
compiled (bool): Whether to use the torch compiled kernel.
|
|
124
|
+
average_log_prob (bool): Whether to average the log probability per non-masked token.
|
|
125
|
+
chunk_size (int): Size of chunks for processing.
|
|
96
126
|
"""
|
|
97
127
|
super().__init__()
|
|
98
128
|
self.ignore_index = ignore_index
|
|
@@ -101,8 +131,16 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
|
|
|
101
131
|
self.label_smoothing = label_smoothing
|
|
102
132
|
self.compute_nll_loss = compute_nll_loss
|
|
103
133
|
self.compiled = compiled
|
|
134
|
+
self.average_log_prob = average_log_prob
|
|
135
|
+
self.chunk_size = chunk_size
|
|
104
136
|
|
|
105
|
-
def forward(
|
|
137
|
+
def forward(
|
|
138
|
+
self,
|
|
139
|
+
lin_weight,
|
|
140
|
+
_input,
|
|
141
|
+
target,
|
|
142
|
+
bias=None,
|
|
143
|
+
):
|
|
106
144
|
return LigerFusedLinearCPOFunction.apply(
|
|
107
145
|
_input,
|
|
108
146
|
lin_weight,
|
|
@@ -114,4 +152,6 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
|
|
|
114
152
|
self.label_smoothing,
|
|
115
153
|
self.compute_nll_loss,
|
|
116
154
|
self.compiled,
|
|
155
|
+
self.average_log_prob,
|
|
156
|
+
self.chunk_size,
|
|
117
157
|
)
|
|
@@ -52,8 +52,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
52
52
|
loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
|
|
53
53
|
return loss, chosen_rewards, rejected_rewards
|
|
54
54
|
|
|
55
|
-
@
|
|
55
|
+
@classmethod
|
|
56
56
|
def forward(
|
|
57
|
+
cls,
|
|
57
58
|
ctx,
|
|
58
59
|
_input,
|
|
59
60
|
weight,
|
|
@@ -67,14 +68,34 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
67
68
|
compute_nll_loss=False,
|
|
68
69
|
compiled=True,
|
|
69
70
|
use_ref_model=True,
|
|
71
|
+
chunk_size=1,
|
|
70
72
|
):
|
|
71
|
-
|
|
73
|
+
"""
|
|
74
|
+
Fused linear layer with DPO loss.
|
|
75
|
+
Args:
|
|
76
|
+
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
77
|
+
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
|
|
78
|
+
target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
|
|
79
|
+
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
|
|
80
|
+
ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
81
|
+
ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
|
|
82
|
+
ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
|
|
83
|
+
ignore_index (int): Index to ignore in loss computation
|
|
84
|
+
beta (float): Weight for the odds ratio loss
|
|
85
|
+
compute_nll_loss (bool): Whether to compute the NLL loss
|
|
86
|
+
compiled (bool): Whether to use torch compile
|
|
87
|
+
use_ref_model (bool): Whether to use a reference model
|
|
88
|
+
chunk_size (int): Size of chunks for processing.
|
|
89
|
+
Returns:
|
|
90
|
+
torch.Tensor: Computed loss
|
|
91
|
+
"""
|
|
92
|
+
return super().forward(
|
|
93
|
+
cls=cls,
|
|
72
94
|
ctx=ctx,
|
|
73
95
|
_input=_input,
|
|
74
96
|
weight=weight,
|
|
75
97
|
target=target,
|
|
76
98
|
bias=bias,
|
|
77
|
-
loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
|
|
78
99
|
ignore_index=ignore_index,
|
|
79
100
|
beta=beta,
|
|
80
101
|
compute_nll_loss=compute_nll_loss,
|
|
@@ -83,12 +104,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
83
104
|
ref_input=ref_input,
|
|
84
105
|
ref_weight=ref_weight,
|
|
85
106
|
ref_bias=ref_bias,
|
|
107
|
+
chunk_size=chunk_size,
|
|
86
108
|
)
|
|
87
109
|
|
|
88
110
|
@staticmethod
|
|
89
111
|
def backward(ctx, *grad_output):
|
|
90
112
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
|
91
|
-
return *grads, None, None, None, None, None, None, None, None
|
|
113
|
+
return *grads, None, None, None, None, None, None, None, None, None
|
|
92
114
|
|
|
93
115
|
|
|
94
116
|
class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
@@ -103,6 +125,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
|
103
125
|
compute_nll_loss: bool = False,
|
|
104
126
|
compiled: bool = True,
|
|
105
127
|
use_ref_model: bool = True,
|
|
128
|
+
chunk_size: int = 1,
|
|
106
129
|
):
|
|
107
130
|
"""
|
|
108
131
|
Args:
|
|
@@ -111,6 +134,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
|
111
134
|
compute_nll_loss (bool): Whether to compute the NLL loss.
|
|
112
135
|
compiled (bool): Whether to use the torch compiled kernel.
|
|
113
136
|
use_ref_model (bool): Whether to use a reference model for the DPO loss.
|
|
137
|
+
chunk_size (int): Size of chunks for processing.
|
|
114
138
|
"""
|
|
115
139
|
super().__init__()
|
|
116
140
|
self.ignore_index = ignore_index
|
|
@@ -118,6 +142,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
|
118
142
|
self.compute_nll_loss = compute_nll_loss
|
|
119
143
|
self.compiled = compiled
|
|
120
144
|
self.use_ref_model = use_ref_model
|
|
145
|
+
self.chunk_size = chunk_size
|
|
121
146
|
|
|
122
147
|
def forward(
|
|
123
148
|
self,
|
|
@@ -142,4 +167,5 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
|
142
167
|
self.compute_nll_loss,
|
|
143
168
|
self.compiled,
|
|
144
169
|
self.use_ref_model,
|
|
170
|
+
self.chunk_size,
|
|
145
171
|
)
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
|
|
2
2
|
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
|
|
3
|
+
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
|
|
3
4
|
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
|
|
4
5
|
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
|
|
5
6
|
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
|
|
@@ -11,3 +12,4 @@ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
|
|
|
11
12
|
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
|
|
12
13
|
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
|
|
13
14
|
liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
|
|
15
|
+
liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
|
|
@@ -115,9 +115,24 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
|
115
115
|
student_logits_chunk /= temperature
|
|
116
116
|
teacher_logits_chunk /= temperature
|
|
117
117
|
|
|
118
|
+
# If the teacher and student token size is different, pad student logits to match the teacher's.
|
|
119
|
+
# This only applies to cases where they share exactly the same vocab and tokenizer just
|
|
120
|
+
# that teacher logit is padded for some training efficiency such as
|
|
121
|
+
# https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
|
|
122
|
+
teacher_vocab_size = teacher_weight.shape[0]
|
|
123
|
+
student_vocab_size = student_weight.shape[0]
|
|
124
|
+
if teacher_vocab_size > student_vocab_size:
|
|
125
|
+
pad_size = teacher_vocab_size - student_vocab_size
|
|
126
|
+
pad_tensor = torch.zeros(
|
|
127
|
+
(*student_logits_chunk.shape[:-1], pad_size),
|
|
128
|
+
dtype=student_logits_chunk.dtype,
|
|
129
|
+
device=student_logits_chunk.device,
|
|
130
|
+
)
|
|
131
|
+
student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
|
|
132
|
+
|
|
118
133
|
hard_loss /= full_target.shape[0]
|
|
119
134
|
|
|
120
|
-
soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
|
|
135
|
+
soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
|
|
121
136
|
soft_loss /= full_target.shape[0]
|
|
122
137
|
|
|
123
138
|
loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
|
|
@@ -125,6 +140,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
|
125
140
|
|
|
126
141
|
@staticmethod
|
|
127
142
|
def forward(
|
|
143
|
+
cls,
|
|
128
144
|
ctx,
|
|
129
145
|
student_input,
|
|
130
146
|
student_weight,
|
|
@@ -133,7 +149,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
|
133
149
|
target,
|
|
134
150
|
student_bias=None,
|
|
135
151
|
teacher_bias=None,
|
|
136
|
-
loss_fn=None,
|
|
137
152
|
chunk_size=1024,
|
|
138
153
|
ignore_index=-100,
|
|
139
154
|
weight_hard_loss=0.5,
|
|
@@ -175,14 +190,14 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
|
175
190
|
|
|
176
191
|
loss_func_to_call = partial(
|
|
177
192
|
LigerFusedLinearDistillationBase._compute_loss,
|
|
178
|
-
distillation_loss_fn=
|
|
193
|
+
distillation_loss_fn=cls.distillation_loss_fn,
|
|
179
194
|
full_target=target,
|
|
180
195
|
ignore_index=ignore_index,
|
|
181
196
|
weight_hard_loss=weight_hard_loss,
|
|
182
197
|
weight_soft_loss=weight_soft_loss,
|
|
183
|
-
beta=beta,
|
|
184
198
|
compute_ce_loss=compute_ce_loss,
|
|
185
199
|
temperature=temperature,
|
|
200
|
+
beta=beta,
|
|
186
201
|
**loss_kwargs,
|
|
187
202
|
)
|
|
188
203
|
|
|
@@ -263,4 +278,4 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
|
263
278
|
grad_weight = grad_weight * grad_output
|
|
264
279
|
grad_bias = grad_bias * grad_output if grad_bias is not None else None
|
|
265
280
|
|
|
266
|
-
return grad_input, grad_weight, None, grad_bias
|
|
281
|
+
return grad_input, grad_weight, None, None, None, grad_bias
|
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch._dynamo.config
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LigerFusedLinearPPOBase(torch.autograd.Function):
|
|
10
|
+
@abstractmethod
|
|
11
|
+
def ppo_loss_fn(*args, **kwargs):
|
|
12
|
+
"""
|
|
13
|
+
To be extended by subclasses.
|
|
14
|
+
"""
|
|
15
|
+
raise NotImplementedError("PPO loss function must be implemented.")
|
|
16
|
+
|
|
17
|
+
@staticmethod
|
|
18
|
+
def forward(
|
|
19
|
+
cls,
|
|
20
|
+
ctx,
|
|
21
|
+
_input,
|
|
22
|
+
weight,
|
|
23
|
+
selected_token_ids,
|
|
24
|
+
attention_mask,
|
|
25
|
+
advantages,
|
|
26
|
+
bias=None,
|
|
27
|
+
ref_per_token_logps=None,
|
|
28
|
+
old_per_token_logps=None,
|
|
29
|
+
ref_input=None,
|
|
30
|
+
ref_weight=None,
|
|
31
|
+
ref_bias=None,
|
|
32
|
+
epsilon_low=0.2,
|
|
33
|
+
epsilon_high=0.2,
|
|
34
|
+
beta=0.04,
|
|
35
|
+
temperature=1.0,
|
|
36
|
+
compiled=True,
|
|
37
|
+
use_ref_model=False,
|
|
38
|
+
chunk_size=1,
|
|
39
|
+
):
|
|
40
|
+
# TODO: check torch compile matmul
|
|
41
|
+
"""Chunked forward pass for PPO loss computation.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
cls: The class
|
|
45
|
+
ctx: Context for backward
|
|
46
|
+
_input: Input tensor
|
|
47
|
+
weight: Weight tensor
|
|
48
|
+
selected_token_ids: Selected token ids tensor
|
|
49
|
+
attention_mask: Attention mask tensor
|
|
50
|
+
advantages: Advantages tensor
|
|
51
|
+
bias: Bias tensor
|
|
52
|
+
ref_per_token_logps: Reference model log probs per token tensor
|
|
53
|
+
old_per_token_logps: Old per token log probabilities tensor
|
|
54
|
+
ref_input: Reference model input tensor
|
|
55
|
+
ref_weight: Reference model weight tensor
|
|
56
|
+
ref_bias: Reference model bias tensor
|
|
57
|
+
epsilon_low: Lower bound for clipping the importance sampling ratio
|
|
58
|
+
epsilon_high: Upper bound for clipping the importance sampling ratio
|
|
59
|
+
beta: Weight for the KL penalty
|
|
60
|
+
temperature: Temperature for the logits
|
|
61
|
+
compiled: Whether to use torch compile
|
|
62
|
+
use_ref_model: Whether to use a reference model
|
|
63
|
+
chunk_size: Size of chunks for processing in other loss modules
|
|
64
|
+
"""
|
|
65
|
+
if use_ref_model:
|
|
66
|
+
assert ref_per_token_logps is not None or ref_input is not None, (
|
|
67
|
+
"If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
|
|
68
|
+
)
|
|
69
|
+
if ref_per_token_logps is not None and ref_input is not None:
|
|
70
|
+
raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
|
|
71
|
+
# Initialize accumulators
|
|
72
|
+
loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
|
|
73
|
+
grad_weight = torch.zeros_like(weight) # [V, H]
|
|
74
|
+
grad_inputs = []
|
|
75
|
+
grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
|
|
76
|
+
aggregated_metrics = []
|
|
77
|
+
|
|
78
|
+
# Create a partial function with fixed arguments
|
|
79
|
+
compute_loss = partial(
|
|
80
|
+
LigerFusedLinearPPOBase._compute_chunk_loss,
|
|
81
|
+
ref_weight=ref_weight,
|
|
82
|
+
ref_bias=ref_bias,
|
|
83
|
+
full_attention_mask=attention_mask,
|
|
84
|
+
epsilon_low=epsilon_low,
|
|
85
|
+
epsilon_high=epsilon_high,
|
|
86
|
+
beta=beta,
|
|
87
|
+
temperature=temperature,
|
|
88
|
+
use_ref_model=use_ref_model,
|
|
89
|
+
ppo_loss_fn=cls.ppo_loss_fn,
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
def fused_fwd_bwd(
|
|
93
|
+
input_chunk,
|
|
94
|
+
selected_token_ids_chunk,
|
|
95
|
+
attention_mask_chunk,
|
|
96
|
+
advantages_chunk,
|
|
97
|
+
ref_per_token_logps_chunk,
|
|
98
|
+
old_per_token_logps_chunk,
|
|
99
|
+
ref_input_chunk,
|
|
100
|
+
):
|
|
101
|
+
"""Fused forward and backward for a chunk."""
|
|
102
|
+
argnums = (0, 1, 5) if bias is not None else (0, 1)
|
|
103
|
+
return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
|
|
104
|
+
input_chunk, # arg 0
|
|
105
|
+
weight, # arg 1
|
|
106
|
+
selected_token_ids_chunk, # arg 2
|
|
107
|
+
attention_mask_chunk, # arg 3
|
|
108
|
+
advantages_chunk, # arg 4
|
|
109
|
+
bias, # arg 5
|
|
110
|
+
ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6
|
|
111
|
+
old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7
|
|
112
|
+
ref_input_chunk=ref_input_chunk, # arg 8
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
def accumulate_chunk(
|
|
116
|
+
input_chunk,
|
|
117
|
+
selected_token_ids_chunk,
|
|
118
|
+
attention_mask_chunk,
|
|
119
|
+
advantages_chunk,
|
|
120
|
+
ref_per_token_logps_chunk=None,
|
|
121
|
+
old_per_token_logps_chunk=None,
|
|
122
|
+
ref_input_chunk=None,
|
|
123
|
+
):
|
|
124
|
+
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
|
|
125
|
+
input_chunk,
|
|
126
|
+
selected_token_ids_chunk,
|
|
127
|
+
attention_mask_chunk,
|
|
128
|
+
advantages_chunk,
|
|
129
|
+
ref_per_token_logps_chunk,
|
|
130
|
+
old_per_token_logps_chunk,
|
|
131
|
+
ref_input_chunk,
|
|
132
|
+
)
|
|
133
|
+
if bias is not None:
|
|
134
|
+
grad_bias.add_(chunk_grad_bias[0])
|
|
135
|
+
|
|
136
|
+
# Accumulate gradients and loss
|
|
137
|
+
grad_weight.add_(chunk_grad_weight)
|
|
138
|
+
grad_inputs.append(chunk_grad_input)
|
|
139
|
+
loss_acc.add_(chunk_loss)
|
|
140
|
+
# Initialize storage for metrics on first chunk
|
|
141
|
+
if len(aggregated_metrics) == 0:
|
|
142
|
+
for metric in chunk_metrics:
|
|
143
|
+
if metric.ndim == 0:
|
|
144
|
+
aggregated_metrics.append(torch.zeros((), device=metric.device))
|
|
145
|
+
else:
|
|
146
|
+
aggregated_metrics.append([])
|
|
147
|
+
|
|
148
|
+
# Accumulate metrics
|
|
149
|
+
for i, metric in enumerate(chunk_metrics):
|
|
150
|
+
if metric.ndim == 0:
|
|
151
|
+
aggregated_metrics[i].add_(metric)
|
|
152
|
+
else:
|
|
153
|
+
aggregated_metrics[i].append(metric)
|
|
154
|
+
|
|
155
|
+
if compiled:
|
|
156
|
+
# TODO: Figure out what is better to compile here
|
|
157
|
+
# accumulate_chunk = torch.compile(accumulate_chunk)
|
|
158
|
+
fused_fwd_bwd = torch.compile(fused_fwd_bwd)
|
|
159
|
+
|
|
160
|
+
# Process input in chunks based on chunk_size
|
|
161
|
+
chunks = max(1, _input.shape[0] // chunk_size)
|
|
162
|
+
_input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
|
|
163
|
+
_selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
|
|
164
|
+
_attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
|
|
165
|
+
_advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
|
|
166
|
+
_ref_per_token_logps_chunks = (
|
|
167
|
+
torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
|
|
168
|
+
if use_ref_model and ref_per_token_logps is not None
|
|
169
|
+
else [None] * chunks
|
|
170
|
+
)
|
|
171
|
+
_old_per_token_logps_chunks = (
|
|
172
|
+
torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
|
|
173
|
+
if old_per_token_logps is not None
|
|
174
|
+
else [None] * chunks
|
|
175
|
+
)
|
|
176
|
+
# if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
|
|
177
|
+
_ref_input_chunks = (
|
|
178
|
+
torch.chunk(ref_input, chunks=chunks, dim=0)
|
|
179
|
+
if use_ref_model and ref_per_token_logps is None
|
|
180
|
+
else [None] * chunks
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
for (
|
|
184
|
+
input_chunk,
|
|
185
|
+
selected_token_ids_chunk,
|
|
186
|
+
attention_mask_chunk,
|
|
187
|
+
advantages_chunk,
|
|
188
|
+
ref_per_token_logps_chunk,
|
|
189
|
+
old_per_token_logps_chunk,
|
|
190
|
+
ref_input_chunk,
|
|
191
|
+
) in zip(
|
|
192
|
+
_input_chunks,
|
|
193
|
+
_selected_token_ids_chunks,
|
|
194
|
+
_attention_mask_chunks,
|
|
195
|
+
_advantages_chunks,
|
|
196
|
+
_ref_per_token_logps_chunks,
|
|
197
|
+
_old_per_token_logps_chunks,
|
|
198
|
+
_ref_input_chunks,
|
|
199
|
+
):
|
|
200
|
+
# Mark dynamic dimensions
|
|
201
|
+
torch._dynamo.mark_dynamic(input_chunk, 1)
|
|
202
|
+
torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
|
|
203
|
+
torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
|
|
204
|
+
if ref_per_token_logps_chunk is not None:
|
|
205
|
+
torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
|
|
206
|
+
if ref_input_chunk is not None:
|
|
207
|
+
torch._dynamo.mark_dynamic(ref_input_chunk, 1)
|
|
208
|
+
if old_per_token_logps_chunk is not None:
|
|
209
|
+
torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
|
|
210
|
+
|
|
211
|
+
accumulate_chunk(
|
|
212
|
+
input_chunk,
|
|
213
|
+
selected_token_ids_chunk,
|
|
214
|
+
attention_mask_chunk,
|
|
215
|
+
advantages_chunk,
|
|
216
|
+
ref_per_token_logps_chunk,
|
|
217
|
+
old_per_token_logps_chunk,
|
|
218
|
+
ref_input_chunk,
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
# Combine gradients
|
|
222
|
+
grad_input = torch.cat(grad_inputs, dim=0)
|
|
223
|
+
|
|
224
|
+
# Save for backward
|
|
225
|
+
ctx.save_for_backward(grad_input, grad_weight, grad_bias)
|
|
226
|
+
|
|
227
|
+
# Finalize metrics
|
|
228
|
+
final_metrics = []
|
|
229
|
+
for metric in aggregated_metrics:
|
|
230
|
+
if isinstance(metric, list):
|
|
231
|
+
final_metrics.append(torch.cat(metric, dim=0))
|
|
232
|
+
else:
|
|
233
|
+
final_metrics.append(metric)
|
|
234
|
+
|
|
235
|
+
return loss_acc, tuple(final_metrics)
|
|
236
|
+
|
|
237
|
+
@staticmethod
|
|
238
|
+
def _compute_chunk_loss(
|
|
239
|
+
input_chunk,
|
|
240
|
+
weight,
|
|
241
|
+
selected_token_ids_chunk,
|
|
242
|
+
attention_mask_chunk,
|
|
243
|
+
advantages_chunk,
|
|
244
|
+
bias=None,
|
|
245
|
+
ref_per_token_logps_chunk=None,
|
|
246
|
+
old_per_token_logps_chunk=None,
|
|
247
|
+
ref_input_chunk=None,
|
|
248
|
+
ref_weight=None,
|
|
249
|
+
ref_bias=None,
|
|
250
|
+
full_attention_mask=None,
|
|
251
|
+
epsilon_low=0.2,
|
|
252
|
+
epsilon_high=0.2,
|
|
253
|
+
beta=0.04,
|
|
254
|
+
temperature=1.0,
|
|
255
|
+
use_ref_model=False,
|
|
256
|
+
ppo_loss_fn=None,
|
|
257
|
+
):
|
|
258
|
+
"""Compute loss for a single chunk."""
|
|
259
|
+
# Get policy log probabilities using chunk_forward
|
|
260
|
+
log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
|
|
261
|
+
|
|
262
|
+
# Get reference log probabilities if needed
|
|
263
|
+
ref_log_probs = None
|
|
264
|
+
if use_ref_model and ref_per_token_logps_chunk is None:
|
|
265
|
+
with torch.no_grad():
|
|
266
|
+
ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
|
|
267
|
+
ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
# Compute chunk loss and metrics using the provided loss function
|
|
271
|
+
chunk_loss, chunk_metrics = ppo_loss_fn(
|
|
272
|
+
log_probs=log_probs,
|
|
273
|
+
selected_token_ids=selected_token_ids_chunk,
|
|
274
|
+
attention_mask=attention_mask_chunk,
|
|
275
|
+
advantages=advantages_chunk,
|
|
276
|
+
full_attention_mask=full_attention_mask,
|
|
277
|
+
ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
|
|
278
|
+
old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
|
|
279
|
+
ref_log_probs=ref_log_probs, # used when ref_per_token_logps is None
|
|
280
|
+
epsilon_low=epsilon_low,
|
|
281
|
+
epsilon_high=epsilon_high,
|
|
282
|
+
beta=beta,
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
return chunk_loss, chunk_metrics
|
|
286
|
+
|
|
287
|
+
@staticmethod
|
|
288
|
+
def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
|
|
289
|
+
"""Forward pass computation for a single chunk without explicit reshaping."""
|
|
290
|
+
# Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
|
|
291
|
+
logits = torch.matmul(input_chunk, weight.t())
|
|
292
|
+
if bias is not None:
|
|
293
|
+
logits = logits + bias # Broadcasts bias to [B, T, V]
|
|
294
|
+
if temperature != 1.0:
|
|
295
|
+
logits = logits / temperature
|
|
296
|
+
|
|
297
|
+
# Compute log probabilities using softmax over the last dimension
|
|
298
|
+
log_probs = F.log_softmax(logits.float(), dim=-1)
|
|
299
|
+
|
|
300
|
+
return log_probs, logits
|
|
301
|
+
|
|
302
|
+
@staticmethod
|
|
303
|
+
def backward(ctx, grad_output, *grad_metrics):
|
|
304
|
+
"""Backward pass for PPO loss."""
|
|
305
|
+
grad_input, grad_weight, grad_bias = ctx.saved_tensors
|
|
306
|
+
if grad_output != 1.0:
|
|
307
|
+
grad_input = grad_input * grad_output
|
|
308
|
+
grad_weight = grad_weight * grad_output
|
|
309
|
+
if grad_bias is not None:
|
|
310
|
+
grad_bias = grad_bias * grad_output
|
|
311
|
+
|
|
312
|
+
return (
|
|
313
|
+
grad_input,
|
|
314
|
+
grad_weight,
|
|
315
|
+
None, # grad_selected_token_ids
|
|
316
|
+
None, # grad_attention_mask
|
|
317
|
+
None, # grad_advantages
|
|
318
|
+
grad_bias,
|
|
319
|
+
None, # grad_ref_per_token_logps
|
|
320
|
+
None, # grad_old_per_token_logps
|
|
321
|
+
None, # grad_ref_input
|
|
322
|
+
None, # grad_ref_weight
|
|
323
|
+
None, # grad_ref_bias
|
|
324
|
+
None, # grad_epsilon_low
|
|
325
|
+
None, # grad_epsilon_high
|
|
326
|
+
None, # grad_beta
|
|
327
|
+
None, # grad_temperature
|
|
328
|
+
None, # grad_compiled
|
|
329
|
+
None, # grad_use_ref_model
|
|
330
|
+
None, # grad_chunk_size
|
|
331
|
+
)
|
|
@@ -16,12 +16,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
|
16
16
|
|
|
17
17
|
@staticmethod
|
|
18
18
|
def forward(
|
|
19
|
+
cls,
|
|
19
20
|
ctx,
|
|
20
21
|
_input,
|
|
21
22
|
weight,
|
|
22
23
|
target,
|
|
23
24
|
bias=None,
|
|
24
|
-
loss_fn=None,
|
|
25
25
|
chunk_size=1,
|
|
26
26
|
ignore_index=-100,
|
|
27
27
|
alpha=1.0,
|
|
@@ -89,7 +89,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
|
89
89
|
|
|
90
90
|
compute_loss = partial(
|
|
91
91
|
LigerFusedLinearPreferenceBase._compute_loss,
|
|
92
|
-
preference_loss_fn=
|
|
92
|
+
preference_loss_fn=cls.preference_loss_fn,
|
|
93
93
|
ignore_index=ignore_index,
|
|
94
94
|
alpha=alpha,
|
|
95
95
|
beta=beta,
|