liger-kernel 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- liger_kernel/chunked_loss/cpo_loss.py +51 -11
- liger_kernel/chunked_loss/dpo_loss.py +30 -4
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +20 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
- liger_kernel/chunked_loss/grpo_loss.py +137 -61
- liger_kernel/chunked_loss/jsd_loss.py +43 -13
- liger_kernel/chunked_loss/kto_loss.py +50 -12
- liger_kernel/chunked_loss/orpo_loss.py +37 -5
- liger_kernel/chunked_loss/simpo_loss.py +47 -11
- liger_kernel/ops/cross_entropy.py +7 -2
- liger_kernel/ops/dyt.py +225 -0
- liger_kernel/ops/fused_linear_jsd.py +2 -1
- liger_kernel/ops/jsd.py +30 -11
- liger_kernel/ops/kl_div.py +2 -2
- liger_kernel/transformers/__init__.py +4 -0
- liger_kernel/transformers/dyt.py +20 -0
- liger_kernel/transformers/functional.py +5 -0
- liger_kernel/transformers/model/gemma.py +8 -16
- liger_kernel/transformers/model/gemma2.py +7 -16
- liger_kernel/transformers/model/llama.py +8 -15
- liger_kernel/transformers/model/llava.py +369 -0
- liger_kernel/transformers/model/loss_utils.py +57 -0
- liger_kernel/transformers/model/mistral.py +9 -10
- liger_kernel/transformers/model/mixtral.py +8 -15
- liger_kernel/transformers/model/mllama.py +8 -15
- liger_kernel/transformers/model/olmo2.py +8 -16
- liger_kernel/transformers/model/paligemma.py +397 -0
- liger_kernel/transformers/model/phi3.py +8 -15
- liger_kernel/transformers/model/qwen2.py +8 -15
- liger_kernel/transformers/model/qwen2_5_vl.py +204 -0
- liger_kernel/transformers/model/qwen2_vl.py +9 -10
- liger_kernel/transformers/monkey_patch.py +286 -12
- liger_kernel/utils.py +1 -3
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +11 -7
- liger_kernel-0.5.6.dist-info/RECORD +80 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
- liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -213
- liger_kernel-0.5.4.dist-info/RECORD +0 -74
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/kto_loss.py
CHANGED

@@ -7,10 +7,10 @@ from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFuse
 class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
     @staticmethod
     def preference_loss_fn(
-        average_log_prob_chunk,
+        log_prob_chunk,
         preference_labels_chunk,
         full_target,
-        ref_average_log_prob_chunk=None,
+        ref_log_prob_chunk=None,
         beta=0.1,
         kl=None,
     ):
@@ -43,30 +43,34 @@
         3. Maintain reasonable distance from the reference model

         Args:
-            average_log_prob_chunk: Average log probabilities for the chunk (batch_size,)
+            log_prob_chunk: Log probabilities for the chunk (batch_size,)
             preference_labels_chunk: Preference labels for the chunk (batch_size,)
             full_target: Non chunked full target tensor
-            ref_average_log_prob_chunk: Reference average log probs for the chunk (batch_size,)
+            ref_log_prob_chunk: Reference log probs for the chunk (batch_size,)
             beta: Weight for the KTO loss
             kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
         Returns:
             - loss: The KTO loss value
         """
-        if ref_average_log_prob_chunk is not None:
-            logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
+        if ref_log_prob_chunk is not None:
+            logratios_chunk = log_prob_chunk - ref_log_prob_chunk
         else:
-            logratios_chunk = average_log_prob_chunk
+            logratios_chunk = log_prob_chunk
         multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
         if kl is not None:
             losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
         else:
             losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)

-        return losses.sum() / (full_target.shape[0])
+        rewards = beta * logratios_chunk
+        chosen_rewards_sum = (rewards * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_rewards_sum = (rewards * (~preference_labels_chunk).unsqueeze(1)).sum()

-    @staticmethod
+        return losses.sum() / (full_target.shape[0]), chosen_rewards_sum, rejected_rewards_sum
+
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -81,15 +85,38 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
         beta=0.1,
         compiled=True,
         use_ref_model=True,
+        average_log_prob=False,
+        chunk_size=1,
     ):
-        return LigerFusedLinearUnpairedPreferenceBase.forward(
+        """
+        Fused linear layer with KTO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            preference_labels (torch.Tensor): Preference labels tensor. Shape: (batch_size,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            kl (torch.Tensor, optional): KL divergence tensor. Shape: (batch_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             preference_labels=preference_labels,
             bias=bias,
-            loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compiled=compiled,
@@ -97,7 +124,9 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             kl=kl,
+            chunk_size=chunk_size,
         )

     @staticmethod
@@ -115,6 +144,7 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
             None,
             None,
             None,
+            None,
         )


@@ -129,6 +159,8 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
         beta: float = 0.1,
         compiled: bool = True,
         use_ref_model: bool = False,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
     ):
         """
         Args:
@@ -136,12 +168,16 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
             beta (float): Temperature parameter for the KTO loss
             compiled (bool): Whether to use compiled operations
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
+        self.chunk_size = chunk_size

     def forward(
         self,
@@ -169,4 +205,6 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
             self.beta,
             self.compiled,
             self.use_ref_model,
+            self.average_log_prob,
+            self.chunk_size,
         )
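Net effect of the kto_loss.py changes: preference_loss_fn now also returns summed chosen/rejected rewards, and the new average_log_prob and chunk_size options are threaded through forward into the unpaired-preference base. A minimal PyTorch sketch of the per-chunk math above, simplified to 1-D shapes; the function name and shapes here are illustrative, not the library API:

    import torch

    def kto_chunk_loss(log_prob, labels, full_batch_size, ref_log_prob=None, beta=0.1, kl=None):
        # log-ratio against the reference policy; falls back to raw log-probs
        logratios = log_prob - ref_log_prob if ref_log_prob is not None else log_prob
        # +1 for desirable (chosen) examples, -1 for undesirable (rejected) ones
        multiplier = torch.where(labels, 1, -1)
        shift = kl if kl is not None else 0.0
        losses = 1 - torch.sigmoid(beta * (logratios - shift) * multiplier)
        rewards = beta * logratios
        # the fused version now returns these sums alongside the scaled loss
        chosen_rewards_sum = rewards[labels].sum()
        rejected_rewards_sum = rewards[~labels].sum()
        return losses.sum() / full_batch_size, chosen_rewards_sum, rejected_rewards_sum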
liger_kernel/chunked_loss/orpo_loss.py
CHANGED

@@ -42,8 +42,9 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):

         return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen

-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -54,25 +55,43 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=True,
         nll_target=None,
         compiled=True,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
+        """
+        Fused linear layer with ORPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            nll_target (torch.LongTensor, optional): Target tensor for NLL loss. Shape: (batch_size * seq_len,)
+            compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             bias=bias,
-            loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             nll_target=nll_target,
             compiled=compiled,
+            chunk_size=chunk_size,
         )

     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None
+        return *grads, None, None, None, None, None, None


 class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -86,19 +105,31 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
         beta: float = 0.1,
         compute_nll_loss: bool = True,
         compiled: bool = True,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
+        self.chunk_size = chunk_size

-    def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+        nll_target=None,
+    ):
         return LigerFusedLinearORPOFunction.apply(
             _input,
             lin_weight,
@@ -109,4 +140,5 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             nll_target,
             self.compiled,
+            self.chunk_size,
         )
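ORPO picks up the same classmethod/super() refactor plus a chunk_size argument on both the autograd Function and the nn.Module wrapper. A hypothetical usage sketch with toy shapes; the import path and the exact return structure are assumptions, and ORPO expects the batch to be chosen/rejected pairs concatenated along dim 0:

    import torch
    from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss  # import path assumed

    B, T, H, V = 4, 8, 64, 128  # batch (chosen+rejected), seq len, hidden, vocab
    lin_weight = torch.randn(V, H, requires_grad=True)
    _input = torch.randn(B * T, H, requires_grad=True)
    target = torch.randint(0, V, (B * T,))

    loss_fn = LigerFusedLinearORPOLoss(beta=0.1, chunk_size=1)  # chunk_size is new in 0.5.6
    out = loss_fn(lin_weight, _input, target)  # loss plus auxiliary reward statistics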
liger_kernel/chunked_loss/simpo_loss.py
CHANGED

@@ -47,8 +47,9 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):

         return loss, chosen_rewards, rejected_rewards

-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -61,27 +62,47 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         gamma=0.5,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
-            ctx=ctx,
-            _input=_input,
-            weight=weight,
-            target=target,
-            bias=bias,
-            loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
-            compute_nll_loss=compute_nll_loss,
+        """
+        Fused linear layer with SimPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            alpha (float): Weight for the alpha parameter
+            label_smoothing (float): Label smoothing factor
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            gamma (float): Weight for the gamma parameter
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
             label_smoothing=label_smoothing,
+            compute_nll_loss=compute_nll_loss,
             compiled=compiled,
             gamma=gamma,
+            chunk_size=chunk_size,
         )

     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None


 class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -98,11 +119,18 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         compute_nll_loss: bool = True,
         compiled: bool = True,
         gamma: float = 0.5,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            alpha (float): Weight for the alpha parameter.
+            label_smoothing (float): Label smoothing factor.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            gamma (float): Weight for the gamma parameter.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
@@ -112,8 +140,15 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.gamma = gamma
+        self.chunk_size = chunk_size

-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+    ):
         return LigerFusedLinearSimPOFunction.apply(
             _input,
             lin_weight,
@@ -126,4 +161,5 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.gamma,
+            self.chunk_size,
         )
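simpo_loss.py mirrors the ORPO changes; SimPO is reference-free, so only the target-margin gamma and the new chunk_size distinguish its configuration. A hypothetical sketch along the same lines, with the same assumptions as above:

    import torch
    from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss  # import path assumed

    B, T, H, V = 4, 8, 64, 128
    lin_weight = torch.randn(V, H, requires_grad=True)
    _input = torch.randn(B * T, H, requires_grad=True)
    target = torch.randint(0, V, (B * T,))

    loss_fn = LigerFusedLinearSimPOLoss(beta=2.0, gamma=0.5, chunk_size=1)
    out = loss_fn(lin_weight, _input, target)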
liger_kernel/ops/cross_entropy.py
CHANGED

@@ -9,6 +9,7 @@ import triton.language as tl
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -59,7 +60,7 @@ def liger_cross_entropy_kernel(
         z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
         loss_stride (int): The stride of the loss tensor.
         n_cols (int): The number of columns in the input tensor.
-        n_non_ignore (int): The number of non-ignored elements in the batch.
+        n_non_ignore (float): The number of non-ignored elements in the batch.
         sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
         weight_sum (float): The sum of weight tensor.
         ignore_index (int): The index to ignore in the target.
@@ -258,7 +259,7 @@ def liger_cross_entropy_kernel(
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning


 def cross_entropy_forward(
@@ -285,6 +286,10 @@ def cross_entropy_forward(

     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
+    assert (target * target_mask).max() < _input.shape[-1], (
+        f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
+    )
+    assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
     sum_non_ignore_weight = n_non_ignore
     weight_sum = 0.0
     if weight is not None:
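The two new assertions make cross_entropy_forward fail fast on corrupt labels instead of letting the Triton kernel index out of bounds. Multiplying by target_mask first zeroes out ignore_index entries, so only real targets are range-checked. A toy repro of what they catch (illustrative tensors, not the library API):

    import torch

    vocab_size = 8
    _input = torch.randn(4, vocab_size)
    target = torch.tensor([1, 3, 8, -100])  # 8 is out of range for vocab_size=8
    ignore_index = -100

    target_mask = target != ignore_index
    # ignore_index entries become 0 after masking, so they never trip the checks
    assert (target * target_mask).max() < _input.shape[-1]  # raises here: 8 >= 8
    assert (target * target_mask).min() >= 0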
liger_kernel/ops/dyt.py
ADDED

@@ -0,0 +1,225 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import infer_device
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
+
+
+@triton.jit
+def _dyt_fwd_kernel(
+    x_ptr,
+    x_row_stride,
+    alpha_ptr,
+    gamma_ptr,
+    beta_ptr,
+    y_ptr,
+    y_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Reference:
+    https://arxiv.org/abs/2503.10622
+
+    Shapes:
+        - x: (BT, C)
+        - alpha: (1)
+        - gamma: (C)
+        - beta: (C)
+    """
+    row_idx = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    x_ptr += row_idx * x_row_stride
+    y_ptr += row_idx * y_row_stride
+
+    alpha = tl.load(alpha_ptr)
+    gamma = tl.load(gamma_ptr + offsets, mask=mask)
+    beta = tl.load(beta_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
+    tl.store(y_ptr + offsets, y, mask=mask)
+
+
+@triton.jit
+def _dyt_bwd_kernel(
+    x_ptr,
+    x_row_stride,
+    dy_ptr,
+    dy_row_stride,
+    dx_ptr,
+    dx_row_stride,
+    alpha_ptr,
+    dalpha_ptr,
+    gamma_ptr,
+    dgamma_ptr,
+    dgamma_row_stride,
+    n_cols,
+    n_rows,
+    ROWS_PER_PROGRAM: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Reference:
+    https://arxiv.org/abs/2503.10622
+
+    Shapes:
+        - x: (BT, C)
+        - alpha: (1)
+        - gamma: (C)
+        - dx: (BT, C)
+        - dy: (BT, C)
+        - dgamma: (sm_count, C)
+        - dalpha: (sm_count,)
+    """
+    # d(gamma * tanh(alpha * x) + beta) / dx
+    # = gamma * (1 - tanh^2(alpha * x)) * alpha
+    # d(gamma * tanh(alpha * x) + beta) / dalpha
+    # = gamma * (1 - tanh^2(alpha * x)) * x
+    # d(gamma * tanh(alpha * x) + beta) / dgamma
+    # = tanh(alpha * x)
+    # d(gamma * tanh(alpha * x)) / dbeta = 1
+    pid = tl.program_id(0)
+
+    row_start = pid * ROWS_PER_PROGRAM
+    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    dalpha = 0.0
+    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    x_ptr += row_start * x_row_stride
+    dx_ptr += row_start * dx_row_stride
+    dy_ptr += row_start * dy_row_stride
+    alpha = tl.load(alpha_ptr)
+    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
+
+    for _ in tl.range(row_start, row_end):
+        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
+        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
+        tanh_ax = tanh((alpha * x).cast(tl.float32))
+        sech2_ax = 1 - tanh_ax * tanh_ax
+
+        dx = dy * gamma * sech2_ax * alpha
+        dalpha += tl.sum(dy * gamma * sech2_ax * x)
+        dgamma += dy * tanh_ax
+        tl.store(dx_ptr + offsets, dx, mask=mask)
+
+        dy_ptr += dy_row_stride
+        x_ptr += x_row_stride
+        dx_ptr += dx_row_stride
+
+    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
+    tl.store(dalpha_ptr + pid, dalpha)
+
+    pass
+
+
+def liger_dyt_fwd(x, alpha, gamma, beta):
+    shape = x.shape
+    dim = shape[-1]
+    x = x.view(-1, dim)
+    n_rows, n_cols = x.shape
+    y = torch.empty_like(x)
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    _dyt_fwd_kernel[(n_rows,)](
+        x_ptr=x,
+        alpha_ptr=alpha,
+        gamma_ptr=gamma,
+        beta_ptr=beta,
+        y_ptr=y,
+        x_row_stride=x.stride(0),
+        y_row_stride=y.stride(0),
+        n_cols=n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return y.view(*shape)
+
+
+def liger_dyt_bwd(dy, x, alpha, gamma):
+    shape = dy.shape
+    dtype = x.dtype
+    dim = shape[-1]
+    dy = dy.view(-1, dim)
+    x = x.view(-1, dim)
+    n_rows, n_cols = dy.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    sm_count = 1
+    device = infer_device()
+    if device == "cuda":
+        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+    elif device == "xpu":
+        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
+    if n_cols > BLOCK_SIZE:
+        raise RuntimeError(
+            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
+        )
+
+    dx = torch.empty_like(x, dtype=torch.float32)
+    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
+    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
+
+    grid = (sm_count,)
+    rows_per_program = triton.cdiv(n_rows, sm_count)
+    _dyt_bwd_kernel[grid](
+        x_ptr=x,
+        x_row_stride=x.stride(0),
+        dy_ptr=dy,
+        dy_row_stride=dy.stride(0),
+        dx_ptr=dx,
+        dx_row_stride=dx.stride(0),
+        alpha_ptr=alpha,
+        dalpha_ptr=_dalpha,
+        gamma_ptr=gamma,
+        dgamma_ptr=_dgamma,
+        dgamma_row_stride=_dgamma.stride(0),
+        n_cols=n_cols,
+        n_rows=n_rows,
+        ROWS_PER_PROGRAM=rows_per_program,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
+    dgamma = _dgamma.sum(dim=0).to(dtype)
+    dbeta = dy.sum(dim=0).to(dtype)
+    return dx.view(*shape), dalpha, dgamma, dbeta
+
+
+class LigerDyTFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, x, alpha, gamma, beta):
+        y = liger_dyt_fwd(x, alpha, gamma, beta)
+        ctx.save_for_backward(x, alpha, gamma)
+        return y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        x, alpha, gamma = ctx.saved_tensors
+        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
+            grad_output,
+            x,
+            alpha,
+            gamma,
+        )
+
+        return (dx, dalpha, dgamma, dbeta)
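dyt.py implements Dynamic Tanh (DyT, arXiv:2503.10622), y = gamma * tanh(alpha * x) + beta, as a fused forward/backward pair; the backward accumulates dalpha and dgamma per streaming multiprocessor and reduces on the host. A sketch checking the fused op against a plain PyTorch reference; the CUDA device, shapes, and tolerances are illustrative:

    import torch
    from liger_kernel.ops.dyt import LigerDyTFunction

    BT, C = 32, 512
    x = torch.randn(BT, C, device="cuda", requires_grad=True)
    alpha = torch.full((1,), 0.5, device="cuda", requires_grad=True)
    gamma = torch.ones(C, device="cuda", requires_grad=True)
    beta = torch.zeros(C, device="cuda", requires_grad=True)

    y = LigerDyTFunction.apply(x, alpha, gamma, beta)
    y_ref = gamma * torch.tanh(alpha * x) + beta  # the DyT definition
    torch.testing.assert_close(y, y_ref, rtol=1e-4, atol=1e-4)
    y.sum().backward()  # exercises liger_dyt_bwd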
liger_kernel/ops/fused_linear_jsd.py
CHANGED

@@ -8,11 +8,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
 from liger_kernel.ops.utils import amp_custom_fwd
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device

 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2


 def fused_linear_jsd_forward(
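This is the same device-dependent cap applied in cross_entropy.py above: infer_device() (from liger_kernel.utils) lets Intel XPU use a 4096-element block while CUDA/ROCm keep 32768. A sketch of the resulting block-size choice; the min/next_power_of_2 pattern is how liger's settings helpers typically cap block sizes, so treat this as illustrative rather than the kernel's exact code path:

    import triton
    from liger_kernel.utils import infer_device

    MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # 32768

    def block_size_for(n_cols: int) -> int:
        # one program covers a whole row of logits, capped at MAX_FUSED_SIZE
        return min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))

    print(block_size_for(50304))  # 32768 on CUDA/ROCm, 4096 on XPU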
liger_kernel/ops/jsd.py
CHANGED

@@ -51,24 +51,43 @@ def _jsd_kernel(
     Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

     if beta == 0.0:  # forward KL
-        Y_prob = tl.exp(Y)
+        Y_max = tl.max(Y, axis=0)
+        Y_shifted = Y - Y_max
+        Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
         loss = Y_prob * (Y - X)
         dX = -Y_prob
-    elif beta == 1.0:
-        X_prob = tl.exp(X)
+    elif beta == 1.0:  # reverse KL
+        X_max = tl.max(X, axis=0)
+        X_shifted = X - X_max
+        X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
         loss = X_prob * (X - Y)
         dX = loss + X_prob
     else:
-        Q = tl.exp(X)
-        P = tl.exp(Y)
-        M = beta * P + (1 - beta) * Q
-        log_M = tl.log(M)
+        max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+        X_shifted = X - max_val
+        Y_shifted = Y - max_val

-        loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
-        dX = (1 - beta) * Q * (X - log_M)
+        # Pre-compute exp(max_val) since it's used twice
+        exp_max = tl.exp(max_val)
+
+        # Compute exp terms with compensation
+        Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+        P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+        # Pre-compute common terms
+        beta_P = beta * P
+        one_minus_beta_Q = (1 - beta) * Q
+        M = beta_P + one_minus_beta_Q
+        log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+        loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+        dX = one_minus_beta_Q * (X - log_M)
+
+    # Pre-compute scaling factor
+    scale = 1.0 / n_non_ignore
+    loss = loss * scale
+    dX = dX * scale

-    loss = loss / n_non_ignore
-    dX = dX / n_non_ignore
     tl.store(loss_ptr + offsets, loss, mask=mask)
     tl.store(dX_ptr + offsets, dX, mask=mask)

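A quick PyTorch check that the shifted-exponential rewrite in the else branch above is algebraically a no-op: exp(x) == exp(x - m) * exp(m), so Q, P, M, and the generalized JSD loss are unchanged, while the shift is intended to keep the exp intermediates in range for log-probability inputs. Toy tensors only:

    import torch

    X = torch.log_softmax(torch.randn(8), dim=0)  # student log-probs
    Y = torch.log_softmax(torch.randn(8), dim=0)  # teacher log-probs
    beta = 0.5

    m = torch.maximum(X.max(), Y.max())
    Q = torch.exp(X - m) * torch.exp(m)  # = exp(X)
    P = torch.exp(Y - m) * torch.exp(m)  # = exp(Y)
    M = beta * P + (1 - beta) * Q
    loss = beta * P * Y + (1 - beta) * Q * X - M * torch.log(M)

    # direct, unshifted computation
    Q0, P0 = X.exp(), Y.exp()
    M0 = beta * P0 + (1 - beta) * Q0
    loss0 = beta * P0 * Y + (1 - beta) * Q0 * X - M0 * M0.log()
    torch.testing.assert_close(loss, loss0)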