liger-kernel 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/dpo_loss.py +8 -1
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/jsd_loss.py +2 -2
- liger_kernel/ops/cross_entropy.py +4 -1
- liger_kernel/ops/dyt.py +113 -179
- liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/sparsemax.py +167 -0
- liger_kernel/transformers/__init__.py +11 -0
- liger_kernel/transformers/dyt.py +5 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +8 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/model/gemma.py +8 -12
- liger_kernel/transformers/model/gemma2.py +8 -10
- liger_kernel/transformers/model/gemma3.py +3 -9
- liger_kernel/transformers/model/glm4.py +119 -0
- liger_kernel/transformers/model/llama.py +64 -15
- liger_kernel/transformers/model/llava.py +0 -8
- liger_kernel/transformers/model/mistral.py +8 -10
- liger_kernel/transformers/model/mixtral.py +8 -12
- liger_kernel/transformers/model/mllama.py +8 -11
- liger_kernel/transformers/model/olmo2.py +8 -10
- liger_kernel/transformers/model/paligemma.py +0 -8
- liger_kernel/transformers/model/phi3.py +8 -12
- liger_kernel/transformers/model/qwen2.py +8 -12
- liger_kernel/transformers/model/qwen2_5_vl.py +3 -7
- liger_kernel/transformers/model/qwen2_vl.py +3 -7
- liger_kernel/transformers/model/qwen3.py +112 -0
- liger_kernel/transformers/model/qwen3_moe.py +128 -0
- liger_kernel/transformers/monkey_patch.py +243 -13
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +21 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/utils.py +11 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +36 -20
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +42 -34
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0

liger_kernel/ops/grpo_loss.py
ADDED
@@ -0,0 +1,310 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _selective_log_softmax_kernel(
+    LOGITS,
+    INPUT_IDS,
+    LOG_P,
+    MASK,
+    TEMPERATURE,
+    stride_input_ids_b,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * stride_input_ids_b + off_l
+    LOG_P += off_b * L + off_l
+
+    if MASK is not None:
+        MASK += off_b * stride_input_ids_b + off_l
+        not_skip = tl.load(MASK)
+        if not_skip == 0:
+            return
+
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    ids = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    tl.store(LOG_P, logp)
+
+
+# compue old_logp and ref_logp, it reduce 10G peak Memory. it does not requires grad
+@torch.no_grad
+def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
+    assert logits.is_contiguous()
+    B, L_ADD_1, N = logits.shape
+    L = L_ADD_1 - 1
+    input_ids = input_ids[:, -L:]
+    if mask is not None:
+        mask = mask[:, -L:]
+    log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
+    kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
+    _selective_log_softmax_kernel[(B, L)](
+        logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
+    )
+    return log_p
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+# for BLOCK_N in [2048, 4096, 8192]
+# for ns in [1, 2, 4]
+# for nw in [1, 2, 4, 8, 16]],
+# key=['N'])
+@triton.jit
+def _grpo_loss_fwd_kernel(
+    LOGITS,
+    OLD_LOGP,
+    REF_LOGP,
+    INPUT_IDS,
+    COMPLETION_MASK,
+    ADVANTAGES,
+    LOSS,
+    LSE,
+    KL,
+    IS_CLIPPED,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LOSS += off_b * L + off_l
+    LSE += off_b * L + off_l
+    IS_CLIPPED += off_b * L + off_l
+
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    if OLD_LOGP is None:
+        old_logp = logp
+    else:
+        OLD_LOGP += off_b * L + off_l
+        old_logp = tl.load(OLD_LOGP).to(tl.float32)
+    coef_1 = tl.exp(logp - old_logp)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
+    is_clipped = per_token_loss1 < per_token_loss2
+
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        KL += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
+        per_token_loss += BETA * kl
+        tl.store(KL, kl)
+
+    tl.store(LOSS, per_token_loss)
+    tl.store(LSE, lse)
+    tl.store(IS_CLIPPED, is_clipped)
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+# for BLOCK_N in [2048, 4096, 8192]
+# for ns in [1, 2, 4]
+# for nw in [1, 2, 4, 8, 16]],
+# key=['N'])
+@triton.jit
+def _grpo_loss_bwd_kernel(
+    DLOSS,
+    DLOGITS,
+    LOGITS,
+    OLD_LOGP,
+    REF_LOGP,
+    INPUT_IDS,
+    ADVANTAGES,
+    COMPLETION_MASK,
+    LSE,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    loss_stride0,
+    loss_stride1,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    DLOGITS += off_b * (L + 1) * N + off_l * N
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            for start in range(0, N, BLOCK_N):
+                cols = tl.arange(0, BLOCK_N) + start
+                tl.store(DLOGITS + cols, 0.0, mask=cols < N)
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    DLOSS += off_b * loss_stride0 + off_l * loss_stride1
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LSE += off_b * L + off_l
+
+    dloss = tl.load(DLOSS).to(tl.float32)
+    lse = tl.load(LSE).to(tl.float32)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    if OLD_LOGP is None:
+        old_logp = logp
+    else:
+        OLD_LOGP += off_b * L + off_l
+        old_logp = tl.load(OLD_LOGP).to(tl.float32)
+    coef_1 = tl.exp(logp - old_logp)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    mask = per_token_loss2 >= per_token_loss1
+
+    dlogp = -per_token_loss1 * mask
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        dlogp += BETA * (1 - tl.exp(ref_logp - logp))
+
+    dlogp = dlogp * dloss / TEMPERATURE
+    tl.debug_barrier()
+    for start_n in tl.range(0, N, BLOCK_N):
+        cols = start_n + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
+        probs = tl.exp(logits - lse)
+        dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
+        tl.store(DLOGITS + cols, dlogits, mask=cols < N)
+
+
+class GrpoLossFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        logits,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace,
+    ):
+        assert logits.is_contiguous() and completion_ids.is_contiguous()
+        assert old_logp is None or old_logp.is_contiguous()
+        assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
+
+        B, L_ADD_1, N = logits.shape
+        L = L_ADD_1 - 1
+
+        if completion_mask is not None:
+            assert completion_mask.is_contiguous()
+
+        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
+        lse = torch.zeros_like(loss)
+        is_clipped = torch.zeros_like(loss)
+        kl = torch.zeros_like(loss) if beta != 0.0 else None
+        kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
+        _grpo_loss_fwd_kernel[(B, L)](
+            logits,
+            old_logp,
+            ref_logp,
+            completion_ids,
+            completion_mask,
+            advantages,
+            loss,
+            lse,
+            kl,
+            is_clipped,
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            L,
+            N,
+            **kwargs,
+        )
+        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
+        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
+        # return loss
+        return loss, kl, is_clipped
+
+    @staticmethod
+    def backward(ctx, *args):
+        dloss = args[0]
+        # print(dloss.shape)
+        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
+        temperature, beta, eps_low, eps_high, inplace = ctx.infos
+        B, L_ADD_1, N = logits.shape
+        L = L_ADD_1 - 1
+        dlogits = logits.data if inplace else torch.empty_like(logits)
+        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
+        _grpo_loss_bwd_kernel[(B, L)](
+            dloss,
+            dlogits,
+            logits,
+            old_logp,
+            ref_logp,
+            completion_ids,
+            advantages,
+            completion_mask,
+            lse,
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            *dloss.stride(),
+            L,
+            N,
+            **kwargs,
+        )
+        dlogits[:, -1, :] = 0
+        return dlogits, None, None, None, None, None, None, None, None, None, None
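
For orientation, here is a minimal sketch of how these new kernels could be driven. The entry points and argument order come from the code above; everything else (shapes, hyperparameter values, the CUDA device) is illustrative, and the higher-level wrapper in liger_kernel/transformers/grpo_loss.py is not reproduced in this diff:

import torch

from liger_kernel.ops.grpo_loss import GrpoLossFunction, fused_selective_log_softmax

# Illustrative sizes: B sequences, L completion tokens, vocab size N; the Triton kernels need a CUDA device.
B, L, N = 2, 16, 32000
logits = torch.randn(B, L + 1, N, device="cuda", requires_grad=True)  # one extra position, as the kernels expect
completion_ids = torch.randint(0, N, (B, L), device="cuda")
completion_mask = torch.ones(B, L, dtype=torch.int32, device="cuda")
advantages = torch.randn(B, device="cuda")

# Reference log-probs computed without autograd via the fused selective log-softmax above.
ref_logp = fused_selective_log_softmax(logits, completion_ids, temperature=1.0, mask=completion_mask)

# Argument order follows GrpoLossFunction.forward: logits, old_logp, ref_logp, completion_ids,
# advantages, completion_mask, temperature, beta, eps_low, eps_high, inplace.
loss, kl, is_clipped = GrpoLossFunction.apply(
    logits, None, ref_logp, completion_ids, advantages, completion_mask, 1.0, 0.04, 0.2, 0.2, False
)
loss.mean().backward()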

liger_kernel/ops/sparsemax.py
ADDED
@@ -0,0 +1,167 @@
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+@triton.jit
+def _sparsemax_forward_kernel(
+    x_ptr,
+    x_stride_row,
+    sorted_x_ptr,
+    sorted_x_stride_row,
+    o_ptr,
+    o_stride_row,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+    num_warps: tl.constexpr,
+):
+    pid_row = tl.program_id(0)
+    ptr_x_data_row = x_ptr + pid_row * x_stride_row
+    ptr_sorted_x_data_row = sorted_x_ptr + pid_row * sorted_x_stride_row
+    ptr_output_row = o_ptr + pid_row * o_stride_row
+
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < n_cols
+
+    z_sorted_block = tl.load(
+        ptr_sorted_x_data_row + offs,
+        mask=mask,
+        other=-float("inf"),
+        cache_modifier=".ca",
+    ).to(tl.float32)
+
+    z_valid = tl.where(mask, z_sorted_block, 0.0)
+    cssv = tl.cumsum(z_valid, 0)
+
+    r = (offs + 1).to(tl.float32)
+    safe_r = tl.where(mask, r, 1.0)
+
+    t_vec = (cssv - 1.0) / safe_r
+
+    support = (z_sorted_block > t_vec) & mask
+
+    k_int = tl.sum(support.to(tl.int32), 0)
+    k_clamped_int = tl.maximum(k_int, 1)
+    k = k_clamped_int.to(tl.float32)
+
+    s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0)
+
+    tau = (s - 1.0) / k
+
+    x_block = tl.load(
+        ptr_x_data_row + offs,
+        mask=mask,
+        other=0.0,
+        cache_modifier=".ca",
+    ).to(tl.float32)
+
+    y = tl.maximum(x_block - tau, 0.0)
+
+    tl.store(
+        ptr_output_row + offs,
+        y.to(ptr_output_row.dtype.element_ty),
+        mask=mask,
+        cache_modifier=".cs",
+    )
+
+
+@triton.jit
+def _sparsemax_backward_kernel(
+    o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
+):
+    row = tl.program_id(0)
+    o_row = o_ptr + row * stride
+    go_row = go_ptr + row * stride
+    gi_row = gi_ptr + row * stride
+
+    offs = tl.arange(0, BLOCK_SIZE)
+
+    supp_cnt = tl.zeros((), tl.float32)
+    go_sum = tl.zeros((), tl.float32)
+
+    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
+        offs_iter = i * BLOCK_SIZE + offs
+        mask_iter = offs_iter < n_cols
+        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
+        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
+        supp = o_val > 0.0
+        go_sum += tl.sum(tl.where(supp, go_val, 0.0))
+        supp_cnt += tl.sum(supp.to(tl.float32))
+
+    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
+        offs_iter = i * BLOCK_SIZE + offs
+        mask_iter = offs_iter < n_cols
+        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
+        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
+        supp = o_val > 0.0
+        gi_val = tl.where(
+            supp,
+            go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32),
+            0.0,
+        )
+        tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")
+
+
+class LigerSparsemaxFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, x: torch.Tensor, dim: int):
+        if dim < 0:
+            dim += x.dim()
+        ctx.dim = dim
+
+        x_sw = x.transpose(dim, -1).contiguous()
+        n_cols = x_sw.size(-1)
+        n_rows = x_sw.numel() // n_cols
+        x_flat = x_sw.view(n_rows, n_cols)
+
+        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+        out_flat = torch.empty_like(x_flat)
+        grid = (n_rows,)
+
+        x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
+
+        _sparsemax_forward_kernel[grid](
+            x_flat,
+            x_flat.stride(0),
+            x_sorted_flat,
+            x_sorted_flat.stride(0),
+            out_flat,
+            out_flat.stride(0),
+            n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+
+        ctx.save_for_backward(out_flat)
+        return out_flat.view_as(x_sw).transpose(dim, -1)
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_out: torch.Tensor):
+        (out_flat,) = ctx.saved_tensors
+        dim = ctx.dim
+
+        go_sw = grad_out.transpose(dim, -1).contiguous()
+        n_cols = go_sw.size(-1)
+        n_rows = go_sw.numel() // n_cols
+        go_flat = go_sw.view(n_rows, n_cols)
+
+        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+        gi_flat = torch.empty_like(go_flat)
+        grid = (n_rows,)
+
+        _sparsemax_backward_kernel[grid](
+            out_flat,
+            go_flat,
+            gi_flat,
+            out_flat.stride(0),
+            n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+
+        return gi_flat.view_as(go_sw).transpose(dim, -1), None
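
As a quick orientation, a minimal sketch of calling the new autograd function directly; shapes and the CUDA device are illustrative, and the nn.Module wrapper added in liger_kernel/transformers/sparsemax.py is not reproduced in this diff:

import torch

from liger_kernel.ops.sparsemax import LigerSparsemaxFunction

x = torch.randn(4, 128, device="cuda", requires_grad=True)
p = LigerSparsemaxFunction.apply(x, -1)  # sparse probabilities along the last dim; each row sums to 1
p.sum().backward()                       # gradients flow through the Triton backward kernel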

liger_kernel/transformers/__init__.py
CHANGED
@@ -14,6 +14,7 @@ from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
+from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
 from liger_kernel.transformers.tvd import LigerTVDLoss  # noqa: F401

@@ -26,6 +27,7 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava  # noqa: F401
@@ -38,6 +40,8 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe  # noqa: F401


 # Check if 'transformers' is installed
@@ -79,6 +83,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_gemma2",
         "apply_liger_kernel_to_gemma3",
         "apply_liger_kernel_to_gemma3_text",
+        "apply_liger_kernel_to_glm4",
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_llama",
         "apply_liger_kernel_to_llava",
@@ -91,6 +96,8 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_qwen2",
         "apply_liger_kernel_to_qwen2_5_vl",
         "apply_liger_kernel_to_qwen2_vl",
+        "apply_liger_kernel_to_qwen3",
+        "apply_liger_kernel_to_qwen3_moe",
     }

     if name in monkey_patch_symbols:
@@ -114,6 +121,7 @@ __all__ = [
     "liger_rotary_pos_emb",
     "LigerBlockSparseTop2MLP",
     "LigerPhi3SwiGLUMLP",
+    "LigerQwen3MoeSwiGLUMLP",
     "LigerSwiGLUMLP",
     "LigerTVDLoss",
 ]
@@ -129,6 +137,7 @@ if _TRANSFORMERS_AVAILABLE:
             "apply_liger_kernel_to_gemma2",
             "apply_liger_kernel_to_gemma3",
             "apply_liger_kernel_to_gemma3_text",
+            "apply_liger_kernel_to_glm4",
             "apply_liger_kernel_to_granite",
             "apply_liger_kernel_to_llama",
             "apply_liger_kernel_to_llava",
@@ -141,5 +150,7 @@ if _TRANSFORMERS_AVAILABLE:
             "apply_liger_kernel_to_qwen2",
             "apply_liger_kernel_to_qwen2_5_vl",
             "apply_liger_kernel_to_qwen2_vl",
+            "apply_liger_kernel_to_qwen3",
+            "apply_liger_kernel_to_qwen3_moe",
         ]
     )
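
In practice the newly exported symbols are used like the existing ones; a small hedged sketch (the patch functions may accept further keyword arguments not shown here):

from liger_kernel.transformers import apply_liger_kernel_to_glm4, apply_liger_kernel_to_qwen3

# Patch the corresponding Hugging Face modeling code in place, before instantiating the model.
apply_liger_kernel_to_qwen3()
apply_liger_kernel_to_glm4()
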
liger_kernel/transformers/dyt.py
CHANGED
@@ -5,16 +5,18 @@ from liger_kernel.ops.dyt import LigerDyTFunction


 class LigerDyT(nn.Module):
-    def __init__(self, hidden_size, init_alpha=0.5):
+    def __init__(self, hidden_size, beta=True, init_alpha=0.5):
         super().__init__()
         self.hidden_size = hidden_size
         self.init_alpha = init_alpha
         self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
         self.gamma = nn.Parameter(torch.ones(hidden_size))
-        self.beta = nn.Parameter(torch.zeros(hidden_size))
+        self.beta = None
+        if beta:
+            self.beta = nn.Parameter(torch.zeros(hidden_size))

     def forward(self, x):
         return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)

     def extra_repr(self):
-        return f"{self.hidden_size}, init_alpha={self.init_alpha}"
+        return f"{self.hidden_size}, init_alpha={self.init_alpha}, beta={self.beta}"
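
A short sketch of the new constructor flag; the values are illustrative, and it assumes the updated LigerDyTFunction accepts beta=None when the bias is disabled:

import torch

from liger_kernel.transformers.dyt import LigerDyT

dyt = LigerDyT(hidden_size=4096, beta=True, init_alpha=0.5).cuda()  # learnable per-channel bias
y = dyt(torch.randn(2, 8, 4096, device="cuda"))

dyt_no_bias = LigerDyT(hidden_size=4096, beta=False)  # beta stays None, so no bias term is applied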

liger_kernel/transformers/fsdp.py
ADDED
@@ -0,0 +1,55 @@
+from typing import Any
+from typing import Callable
+
+from torch.distributed.fsdp import FullyShardedDataParallel
+
+
+class _FSDPForwardRedirection:
+    """
+    Modified based on
+    https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
+    Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
+    post-forward can be properly executed around the method call.
+    This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
+    the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
+    GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
+    will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
+    the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
+    its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
+    the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
+    """
+
+    def __call__(
+        self,
+        wrapper_module: FullyShardedDataParallel,
+        method: Callable,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        """Reroutes a method call through the `wrapper_module`'s `forward` method.
+        Args:
+            wrapper_module: The module that has `original_module` wrapped.
+            original_module: The module that was wrapped inside `wrapper_module`.
+            method_name: The name of the method that should be called on the `original_module` after inputs get
+                redirected through the `wrapper_module`'s `forward` method.
+            *args: The positional arguments to the method `method_name`. They will get passed to a patched
+                `forward` method instead.
+            **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
+                `forward` method instead.
+        """
+        assert isinstance(wrapper_module, FullyShardedDataParallel)
+        original_module = wrapper_module._fsdp_wrapped_module
+        original_forward = original_module.forward
+
+        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
+            # Unpatch ourselves immediately before calling the method `method_name`
+            # because itself may want to call the real `forward`
+            original_module.forward = original_forward  # type: ignore[method-assign]
+            # Call the actual method e.g. `.training_step(...)`
+            out = method(*_args, **_kwargs)
+            return out
+
+        # Patch the original_module's forward so we can redirect the arguments back to the real method
+        original_module.forward = wrapped_forward  # type: ignore[method-assign]
+        wrapper_output = wrapper_module(*args, **kwargs)
+        return wrapper_output
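
A hedged sketch of how this helper is intended to be used, following its own docstring; the FSDP-wrapped model and its inputs are placeholders:

from liger_kernel.transformers.fsdp import _FSDPForwardRedirection

# Run only the inner LlamaModel of an FSDP-wrapped LlamaForCausalLM while still going through the
# FSDP root forward, so its parameters are all-gathered correctly.
redirect = _FSDPForwardRedirection()
hidden_states = redirect(
    fsdp_wrapped_causal_lm,                       # a FullyShardedDataParallel instance (placeholder)
    fsdp_wrapped_causal_lm.module.model.forward,  # the submodule method to call
    input_ids=input_ids,                          # placeholder inputs
    attention_mask=attention_mask,
)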

liger_kernel/transformers/functional.py
CHANGED
@@ -12,6 +12,7 @@ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
 from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
 from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 from liger_kernel.ops.rope import LigerRopeFunction
+from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction
 from liger_kernel.ops.tvd import LigerTVDLossFunction

@@ -159,6 +160,13 @@ def liger_kl_div(
     )


+def liger_sparsemax(
+    input,
+    dim: int = -1,
+):
+    return LigerSparsemaxFunction.apply(input, dim)
+
+
 def liger_tvd(
     input,
     target,
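
The functional form mirrors the other liger_* wrappers; a brief illustrative call (shapes and device are placeholders):

import torch

from liger_kernel.transformers.functional import liger_sparsemax

scores = torch.randn(2, 8, 64, device="cuda")
probs = liger_sparsemax(scores, dim=-1)  # sparsemax along the last dimension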

liger_kernel/transformers/fused_linear_cross_entropy.py
CHANGED
@@ -23,8 +23,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
         assert reduction in {
             "mean",
             "sum",
-
-        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+        }, f"reduction must be 'mean' or 'sum'. Got: {reduction}"
         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
         self.ce_weight = ce_weight
         self.ignore_index = ignore_index