liger-kernel 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cpo_loss.py +51 -11
- liger_kernel/chunked_loss/dpo_loss.py +30 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +3 -3
- liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
- liger_kernel/chunked_loss/fused_linear_rlhf.py +240 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
- liger_kernel/chunked_loss/grpo_loss.py +194 -0
- liger_kernel/chunked_loss/jsd_loss.py +31 -6
- liger_kernel/chunked_loss/kto_loss.py +53 -15
- liger_kernel/chunked_loss/orpo_loss.py +37 -5
- liger_kernel/chunked_loss/simpo_loss.py +47 -11
- liger_kernel/ops/cross_entropy.py +7 -3
- liger_kernel/ops/fused_linear_cross_entropy.py +3 -3
- liger_kernel/ops/fused_linear_jsd.py +3 -3
- liger_kernel/ops/jsd.py +3 -3
- liger_kernel/ops/layer_norm.py +20 -7
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +1 -2
- liger_kernel/transformers/__init__.py +4 -0
- liger_kernel/transformers/cross_entropy.py +3 -3
- liger_kernel/transformers/functional.py +17 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +3 -3
- liger_kernel/transformers/group_norm.py +6 -6
- liger_kernel/transformers/model/olmo2.py +124 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +205 -0
- liger_kernel/transformers/monkey_patch.py +239 -27
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/utils.py +48 -1
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/METADATA +19 -4
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/RECORD +35 -29
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/WHEEL +1 -1
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/kto_loss.py CHANGED

@@ -7,10 +7,10 @@ from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFuse
 class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
     @staticmethod
     def preference_loss_fn(
+        log_prob_chunk,
         preference_labels_chunk,
         full_target,
+        ref_log_prob_chunk=None,
         beta=0.1,
         kl=None,
     ):
@@ -43,30 +43,34 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
         3. Maintain reasonable distance from the reference model

         Args:
+            log_prob_chunk: Log probabilities for the chunk (batch_size,)
+            preference_labels_chunk: Preference labels for the chunk (batch_size,)
             full_target: Non chunked full target tensor
-            beta: Weight for the direct preference loss
+            ref_log_prob_chunk: Reference log probs for the chunk (batch_size,)
+            beta: Weight for the KTO loss
             kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
         Returns:
-            Tuple of (loss, chosen_rewards, rejected_rewards):
             - loss: The KTO loss value
-            - chosen_rewards: Reward signals for chosen responses (detached)
-            - rejected_rewards: Reward signals for rejected responses (detached)
         """
+        if ref_log_prob_chunk is not None:
+            logratios_chunk = log_prob_chunk - ref_log_prob_chunk
+        else:
+            logratios_chunk = log_prob_chunk
         multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
         if kl is not None:
             losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
         else:
             losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)

+        rewards = beta * logratios_chunk
+        chosen_rewards_sum = (rewards * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_rewards_sum = (rewards * (~preference_labels_chunk).unsqueeze(1)).sum()

+        return losses.sum() / (full_target.shape[0]), chosen_rewards_sum, rejected_rewards_sum
+
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -81,15 +85,38 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
         beta=0.1,
         compiled=True,
         use_ref_model=True,
+        average_log_prob=False,
+        chunk_size=1,
     ):
+        """
+        Fused linear layer with KTO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            preference_labels (torch.Tensor): Preference labels tensor. Shape: (batch_size,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            kl (torch.Tensor, optional): KL divergence tensor. Shape: (batch_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             preference_labels=preference_labels,
             bias=bias,
-            loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compiled=compiled,
@@ -97,7 +124,9 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             kl=kl,
+            chunk_size=chunk_size,
         )

     @staticmethod
@@ -115,6 +144,7 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
             None,
             None,
             None,
+            None,
         )


@@ -129,6 +159,8 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
         beta: float = 0.1,
         compiled: bool = True,
         use_ref_model: bool = False,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
     ):
         """
         Args:
@@ -136,12 +168,16 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
             beta (float): Temperature parameter for the KTO loss
             compiled (bool): Whether to use compiled operations
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
+        self.chunk_size = chunk_size

     def forward(
         self,
@@ -169,4 +205,6 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
             self.beta,
             self.compiled,
             self.use_ref_model,
+            self.average_log_prob,
+            self.chunk_size,
         )
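For reference, the constructor arguments visible in the hunks above can be exercised as follows. This is an illustrative sketch, not part of the diff; the argument values are arbitrary.

    from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss

    # average_log_prob and chunk_size are the options introduced in 0.5.5;
    # the remaining arguments already existed in 0.5.3.
    kto_loss = LigerFusedLinearKTOLoss(
        ignore_index=-100,
        beta=0.1,
        compiled=True,
        use_ref_model=False,
        average_log_prob=True,  # average per-token log probabilities instead of summing them
        chunk_size=1,           # number of sequences processed per chunk
    )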
liger_kernel/chunked_loss/orpo_loss.py CHANGED

@@ -42,8 +42,9 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):

         return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen

-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -54,25 +55,43 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=True,
         nll_target=None,
         compiled=True,
+        chunk_size=1,
     ):
+        """
+        Fused linear layer with ORPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            nll_target (torch.LongTensor, optional): Target tensor for NLL loss. Shape: (batch_size * seq_len,)
+            compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             bias=bias,
-            loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             nll_target=nll_target,
             compiled=compiled,
+            chunk_size=chunk_size,
         )

     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None
+        return *grads, None, None, None, None, None, None


 class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -86,19 +105,31 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
         beta: float = 0.1,
         compute_nll_loss: bool = True,
         compiled: bool = True,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
+        self.chunk_size = chunk_size

-    def forward(
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+        nll_target=None,
+    ):
         return LigerFusedLinearORPOFunction.apply(
             _input,
             lin_weight,
@@ -109,4 +140,5 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             nll_target,
             self.compiled,
+            self.chunk_size,
         )
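A minimal usage sketch of the reworked module interface. This is illustrative only: the tensor shapes are made up, and the assumption that chosen rows come before rejected rows follows the library's paired-preference convention rather than anything shown in this hunk.

    import torch
    from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss

    B, T, H, V = 2, 4, 16, 32                      # one chosen + one rejected sequence
    lm_head_weight = torch.randn(V, H, requires_grad=True)
    hidden_states = torch.randn(B * T, H)          # chosen rows first, rejected rows second
    target = torch.randint(0, V, (B * T,))

    orpo_loss = LigerFusedLinearORPOLoss(beta=0.1, compiled=False, chunk_size=1)
    output = orpo_loss(lm_head_weight, hidden_states, target)  # return structure is not shown in this excerpt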
liger_kernel/chunked_loss/simpo_loss.py CHANGED

@@ -47,8 +47,9 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):

         return loss, chosen_rewards, rejected_rewards

-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -61,27 +62,47 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         gamma=0.5,
+        chunk_size=1,
     ):
+        """
+        Fused linear layer with SimPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            alpha (float): Weight for the alpha parameter
+            label_smoothing (float): Label smoothing factor
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            gamma (float): Weight for the gamma parameter
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
             label_smoothing=label_smoothing,
+            compute_nll_loss=compute_nll_loss,
             compiled=compiled,
             gamma=gamma,
+            chunk_size=chunk_size,
         )

     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None


 class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -98,11 +119,18 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         compute_nll_loss: bool = True,
         compiled: bool = True,
         gamma: float = 0.5,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            alpha (float): Weight for the alpha parameter.
+            label_smoothing (float): Label smoothing factor.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            gamma (float): Weight for the gamma parameter.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
@@ -112,8 +140,15 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.gamma = gamma
+        self.chunk_size = chunk_size

-    def forward(
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+    ):
         return LigerFusedLinearSimPOFunction.apply(
             _input,
             lin_weight,
@@ -126,4 +161,5 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.gamma,
+            self.chunk_size,
         )
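The same pattern applies to SimPO. A sketch of the new constructor argument (illustrative, not part of the diff; values are arbitrary):

    from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss

    # chunk_size is the argument added in 0.5.5; forward now takes
    # (lin_weight, _input, target, bias=None), as spelled out in the hunk above.
    simpo_loss = LigerFusedLinearSimPOLoss(
        beta=0.1,
        gamma=0.5,
        label_smoothing=0.0,
        compute_nll_loss=False,
        compiled=True,
        chunk_size=1,
    )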
liger_kernel/ops/cross_entropy.py CHANGED

@@ -285,13 +285,17 @@ def cross_entropy_forward(

     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
+    assert (target * target_mask).max() < _input.shape[-1], (
+        f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
+    )
+    assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
     sum_non_ignore_weight = n_non_ignore
     weight_sum = 0.0
     if weight is not None:
         assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
-        assert torch.is_floating_point(
-            weight
-        )
+        assert torch.is_floating_point(weight), (
+            f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        )
         sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
         weight_sum = weight.sum().item()
         # ensure weight is contiguous
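The new bounds checks multiply the target by the ignore mask so that ignore_index entries (typically -100) cannot trip the non-negativity check. A small self-contained illustration of what the assertions enforce (not part of the diff):

    import torch

    V = 10
    ignore_index = -100
    target = torch.tensor([3, 9, ignore_index, 5])
    target_mask = target != ignore_index

    # Ignored positions are zeroed out before the range checks, so only real
    # labels must satisfy 0 <= label < V (the vocabulary size).
    assert (target * target_mask).max() < V
    assert (target * target_mask).min() >= 0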
liger_kernel/ops/fused_linear_cross_entropy.py CHANGED

@@ -58,9 +58,9 @@ def fused_linear_cross_entropy_forward(
     ce_weight_sum = 0.0
     if ce_weight is not None:
         assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
-        assert torch.is_floating_point(
-            ce_weight
-        )
+        assert torch.is_floating_point(ce_weight), (
+            f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+        )
         total_sum_non_ignore_ce_weight = (
             torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
         )
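The reshaped assertion only adds an error message; the underlying check is unchanged and simply requires a floating-point dtype for the class weights. For reference (illustrative, not part of the diff):

    import torch

    ce_weight = torch.ones(10)                      # float32 passes torch.is_floating_point
    bad_weight = torch.ones(10, dtype=torch.long)   # an integer dtype would now fail with the message above

    assert torch.is_floating_point(ce_weight)
    assert not torch.is_floating_point(bad_weight)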
liger_kernel/ops/fused_linear_jsd.py CHANGED

@@ -195,9 +195,9 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
         """
         has_label = False
         if shift_labels is not None:
-            assert shift_labels.shape == (
-            )
+            assert shift_labels.shape == (teacher_input.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
             shift_labels = shift_labels.contiguous()
             has_label = True
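The rewritten assertion ties shift_labels to the flattened (BT,) layout of the already-shifted student and teacher inputs. A hedged sketch of how such labels are typically built from a (B, T) label tensor; the shifting convention is an assumption, not something stated in this hunk:

    import torch

    B, T = 2, 5
    labels = torch.randint(0, 100, (B, T))

    # Drop the first position of each sequence and flatten, giving shape (B * (T - 1),),
    # which must match teacher_input.shape[0] for the assertion above to pass.
    shift_labels = labels[..., 1:].contiguous().view(-1)
    print(shift_labels.shape)  # torch.Size([8])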
liger_kernel/ops/jsd.py CHANGED

@@ -157,9 +157,9 @@ class LigerJSDFunction(torch.autograd.Function):
         """
         has_label = False
         if shift_labels is not None:
-            assert shift_labels.shape == (
-            )
+            assert shift_labels.shape == (_input.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
             shift_labels = shift_labels.contiguous()
             has_label = True
liger_kernel/ops/layer_norm.py CHANGED

@@ -57,13 +57,14 @@ def _layer_norm_forward_kernel(
     B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)

     mean = tl.sum(X_row, axis=0) / n_cols
+    Xmm = tl.where(mask, X_row - mean, 0)
+    var = tl.sum(Xmm * Xmm, axis=0) / n_cols
     rstd = rsqrt(var + eps)

     tl.store(Mean_ptr, mean)
     tl.store(RSTD_ptr, rstd)

-    Y_row =
+    Y_row = Xmm * rstd * W_row + B_row

     tl.store(Y_ptr + col_offsets, Y_row, mask=mask)

@@ -147,9 +148,11 @@ def layer_norm_forward(X, W, B, eps):
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
     RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+    if X.shape[1] != W.shape[0]:
+        raise ValueError(
+            f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
+            f"must match weight size (W.shape[0]={W.shape[0]})"
+        )

     _layer_norm_forward_kernel[(n_rows,)](
         Y,
@@ -190,11 +193,21 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):

     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
-        raise RuntimeError(
+        raise RuntimeError(
+            f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
+        )

     rows_per_program = math.ceil(n_rows / sm_count)
     grid = (sm_count,)
-    triton_dtype =
+    triton_dtype = (
+        tl.float32
+        if X.dtype == torch.float32
+        else tl.bfloat16
+        if X.dtype == torch.bfloat16
+        else tl.float16
+        if X.dtype == torch.float16
+        else tl.float32  # fallback to float32 for other types
+    )
     _layer_norm_backward_kernel[grid](
         X,
         W,
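The forward-kernel change recomputes the variance from explicitly masked, mean-centered values (Xmm) and reuses them for the output. Below is an eager-mode reference of the same per-row math, checked against PyTorch's layer_norm; it is illustrative only, and eps and the shapes are arbitrary:

    import torch

    def layer_norm_row_reference(x_row, w_row, b_row, eps=1e-6):
        # Mirrors the updated kernel: center, take the biased variance of the
        # centered values, then scale and shift. The Triton kernel additionally
        # masks padded lanes to zero, which has no eager-mode counterpart here.
        mean = x_row.mean()
        xmm = x_row - mean
        var = (xmm * xmm).mean()
        rstd = torch.rsqrt(var + eps)
        return xmm * rstd * w_row + b_row

    x, w, b = torch.randn(8), torch.ones(8), torch.zeros(8)
    expected = torch.nn.functional.layer_norm(x, (8,), w, b, eps=1e-6)
    torch.testing.assert_close(layer_norm_row_reference(x, w, b), expected)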
liger_kernel/ops/tvd.py ADDED

@@ -0,0 +1,207 @@
+from typing import Literal
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import ensure_contiguous
+
+MAX_FUSED_SIZE = 65536 // 4
+
+REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
+
+_REDUCTION_MODE_NONE = tl.constexpr(0)
+_REDUCTION_MODE_SUM = tl.constexpr(1)
+_REDUCTION_MODE_MEAN = tl.constexpr(2)
+_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
+
+_str_to_reduction_mode = {
+    "none": _REDUCTION_MODE_NONE.value,
+    "sum": _REDUCTION_MODE_SUM.value,
+    "mean": _REDUCTION_MODE_MEAN.value,
+    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
+}
+
+
+def get_num_warps(BLOCK_SIZE):
+    num_warps = 4
+    if BLOCK_SIZE >= 32768:
+        num_warps = 32
+    elif BLOCK_SIZE >= 8192:
+        num_warps = 16
+    elif BLOCK_SIZE >= 2048:
+        num_warps = 8
+
+    return num_warps
+
+
+@triton.jit
+def _tv_distance_kernel(
+    p_ptr,
+    p_stride,
+    q_ptr,
+    q_stride,
+    loss_ptr,
+    loss_stride,
+    grads_ptr,
+    grads_stride,
+    label_ptr,
+    ignore_index: tl.constexpr,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+    HAS_LABEL: tl.constexpr,
+    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
+):
+    pid = tl.program_id(0).to(tl.int64)
+    p_ptr += pid * p_stride
+    q_ptr += pid * q_stride
+    loss_ptr += pid * loss_stride
+    grads_ptr += pid * grads_stride
+    label_ptr += pid
+
+    base_offsets = tl.arange(0, BLOCK_SIZE)
+
+    if HAS_LABEL:
+        label = tl.load(label_ptr)
+        if label == ignore_index:
+            for i in range(0, n_cols, BLOCK_SIZE):
+                offsets = i + base_offsets
+                mask = offsets < n_cols
+                tl.store(grads_ptr + offsets, 0.0, mask=mask)
+                if reduction == _REDUCTION_MODE_NONE:
+                    tl.store(loss_ptr + offsets, 0.0, mask=mask)
+            return
+
+    loss_sum = 0.0
+    for i in range(0, n_cols, BLOCK_SIZE):
+        offsets = i + base_offsets
+        mask = offsets < n_cols
+
+        p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
+        q = tl.load(q_ptr + offsets, mask=mask, other=0.0)
+
+        # TVD(P || Q) = 0.5 * |P - Q|
+        tv_loss = 0.5 * tl.abs(p - q)
+
+        grad_res = tl.where(p > q, 0.5, -0.5)
+
+        tl.store(grads_ptr + offsets, grad_res, mask=mask)
+
+        if reduction == _REDUCTION_MODE_NONE:
+            tl.store(loss_ptr + offsets, tv_loss, mask=mask)
+        else:
+            loss_sum += tl.sum(tv_loss, axis=0)
+
+    if reduction != _REDUCTION_MODE_NONE:
+        tl.store(loss_ptr, loss_sum)
+
+
+def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
+    BT, V = p.shape
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    num_warps = get_num_warps(BLOCK_SIZE)
+
+    grid = (BT,)
+
+    reduction = _str_to_reduction_mode[reduction]
+
+    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
+    output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
+    grads = torch.empty_like(p)
+
+    n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
+
+    _tv_distance_kernel[grid](
+        p,
+        p.stride(0),
+        q,
+        q.stride(0),
+        output_tensor,
+        output_tensor.stride(0),
+        grads,
+        grads.stride(0),
+        shift_labels if has_label else torch.empty(1, device=p.device),
+        ignore_index,
+        V,
+        BLOCK_SIZE=BLOCK_SIZE,
+        HAS_LABEL=has_label,
+        num_warps=num_warps,
+        reduction=reduction,
+    )
+
+    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
+        return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
+    elif reduction == _REDUCTION_MODE_SUM.value:
+        return output_tensor.sum(dim=0), grads
+    elif reduction == _REDUCTION_MODE_MEAN.value:
+        return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
+    else:
+        return output_tensor, grads
+
+
+def tvd_backward_triton(grad_output, grads):
+    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
+    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        return grads
+
+    return grads * grad_output
+
+
+class LigerTVDLossFunction(torch.autograd.Function):
+    """
+    Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        p: torch.Tensor,
+        q: torch.Tensor,
+        shift_labels: Optional[torch.Tensor] = None,
+        reduction: REDUCTION_LITERAL = "batchmean",
+        ignore_index: int = -100,
+    ) -> torch.Tensor:
+        """A forward pass for the Total Variation Distance Loss.
+
+        Args:
+            ctx: Torch autograd context
+            p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
+            q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
+            shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
+            reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
+            ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
+
+        Returns:
+            torch.Tensor: The computed Total Variation Distance Loss.
+        """
+        has_label = False
+        if shift_labels is not None:
+            assert shift_labels.shape == (p.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
+            shift_labels = shift_labels.contiguous()
+            has_label = True
+
+        loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
+        ctx.save_for_backward(grads)
+        return loss
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        """A backward pass for the Total Variation Distance Loss.
+
+        Args:
+            ctx: Torch autograd context
+            grad_output (torch.Tensor): The gradient of the loss with respect to the output.
+
+        Returns:
+            tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
+        """
+        (grads,) = ctx.saved_tensors
+        grads = tvd_backward_triton(grad_output, grads)
+
+        return grads, None, None, None, None