liger-kernel 0.6.3__py3-none-any.whl → 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +18 -5
- liger_kernel/ops/cross_entropy.py +59 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +30 -4
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +84 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/transformers/__init__.py +19 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/functional.py +24 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
- liger_kernel/transformers/grpo_loss.py +56 -1
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +14 -5
- liger_kernel/transformers/model/gemma3.py +25 -12
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +22 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +17 -5
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/monkey_patch.py +398 -20
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/swiglu.py +17 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/METADATA +4 -1
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/RECORD +55 -48
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/WHEEL +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/top_level.txt +0 -0
|
@@ -27,8 +27,12 @@ def fused_linear_cross_entropy_forward(
|
|
|
27
27
|
return_z_loss=False,
|
|
28
28
|
accum_dtype=None,
|
|
29
29
|
use_token_scaling=False,
|
|
30
|
+
return_token_accuracy=False,
|
|
30
31
|
):
|
|
31
32
|
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
|
|
33
|
+
assert isinstance(return_token_accuracy, bool), (
|
|
34
|
+
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
|
|
35
|
+
)
|
|
32
36
|
device = _input.device
|
|
33
37
|
|
|
34
38
|
input_requires_grad = _input.requires_grad
|
|
@@ -58,9 +62,13 @@ def fused_linear_cross_entropy_forward(
|
|
|
58
62
|
else:
|
|
59
63
|
grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
|
|
60
64
|
grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
|
|
65
|
+
else:
|
|
66
|
+
grad_weight = None
|
|
67
|
+
grad_bias = None
|
|
61
68
|
|
|
62
69
|
loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
|
|
63
70
|
z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
|
|
71
|
+
token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
|
|
64
72
|
|
|
65
73
|
# TODO: evaluate how CUDA synchronization caused by .item() affects the speed
|
|
66
74
|
target_mask = target != ignore_index
|
|
@@ -126,6 +134,7 @@ def fused_linear_cross_entropy_forward(
|
|
|
126
134
|
# unreduced loss
|
|
127
135
|
loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
|
|
128
136
|
z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
|
|
137
|
+
token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
|
|
129
138
|
|
|
130
139
|
# ensure _input and target are contiguous
|
|
131
140
|
logits_chunk = logits_chunk.contiguous()
|
|
@@ -141,6 +150,10 @@ def fused_linear_cross_entropy_forward(
|
|
|
141
150
|
loss_ptr=loss_1d_slice,
|
|
142
151
|
z_loss_ptr=z_loss_1d_slice,
|
|
143
152
|
loss_stride=loss_1d_slice.stride(-1), # always 1
|
|
153
|
+
token_accuracy_ptr=token_accuracy_1d_slice,
|
|
154
|
+
token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
|
|
155
|
+
if return_token_accuracy
|
|
156
|
+
else 0, # always 1 if accuracy is enabled
|
|
144
157
|
n_cols=V,
|
|
145
158
|
n_non_ignore=total_n_non_ignore,
|
|
146
159
|
sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
|
|
@@ -151,6 +164,7 @@ def fused_linear_cross_entropy_forward(
|
|
|
151
164
|
reduction=reduction,
|
|
152
165
|
softcap=softcap,
|
|
153
166
|
RETURN_Z_LOSS=return_z_loss,
|
|
167
|
+
RETURN_TOKEN_ACCURACY=return_token_accuracy,
|
|
154
168
|
HAS_WEIGHT=True if ce_weight is not None else False,
|
|
155
169
|
HAS_SOFTCAPPING=True if softcap is not None else False,
|
|
156
170
|
HAS_GRADIENTS=input_requires_grad,
|
|
@@ -167,6 +181,8 @@ def fused_linear_cross_entropy_forward(
|
|
|
167
181
|
loss_1d[start_idx:end_idx] = loss_1d_slice
|
|
168
182
|
if return_z_loss:
|
|
169
183
|
z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
|
|
184
|
+
if return_token_accuracy:
|
|
185
|
+
token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
|
|
170
186
|
grad_logits_chunk = logits_chunk # chunk_size x V
|
|
171
187
|
|
|
172
188
|
# Apply token scaling to gradients if requested
|
|
@@ -198,15 +214,18 @@ def fused_linear_cross_entropy_forward(
|
|
|
198
214
|
# Return per-token losses
|
|
199
215
|
loss = loss_1d
|
|
200
216
|
z_loss = z_loss_1d if return_z_loss else None
|
|
217
|
+
token_accuracy = token_accuracy_1d if return_token_accuracy else None
|
|
201
218
|
else:
|
|
202
219
|
loss = torch.sum(loss_1d)
|
|
203
220
|
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
|
|
221
|
+
# For accuracy, we compute the mean across all non-ignored tokens
|
|
222
|
+
token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
|
|
204
223
|
|
|
205
224
|
# Cast back to original dtype
|
|
206
225
|
grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
|
|
207
226
|
grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
|
|
208
227
|
|
|
209
|
-
return loss, z_loss, grad_input, grad_weight, grad_bias
|
|
228
|
+
return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias
|
|
210
229
|
|
|
211
230
|
|
|
212
231
|
def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
|
|
@@ -274,6 +293,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
|
274
293
|
return_z_loss: bool = False,
|
|
275
294
|
accum_dtype=None,
|
|
276
295
|
use_token_scaling: bool = False,
|
|
296
|
+
return_token_accuracy: bool = False,
|
|
277
297
|
):
|
|
278
298
|
"""
|
|
279
299
|
Fusing the last linear layer with cross-entropy loss
|
|
@@ -297,9 +317,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
|
297
317
|
use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
|
|
298
318
|
When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
|
|
299
319
|
Default: False.
|
|
320
|
+
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
|
|
300
321
|
"""
|
|
301
322
|
|
|
302
|
-
loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
|
|
323
|
+
loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
|
|
303
324
|
_input=_input,
|
|
304
325
|
weight=weight,
|
|
305
326
|
target=target,
|
|
@@ -313,6 +334,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
|
313
334
|
return_z_loss=return_z_loss,
|
|
314
335
|
accum_dtype=accum_dtype,
|
|
315
336
|
use_token_scaling=use_token_scaling,
|
|
337
|
+
return_token_accuracy=return_token_accuracy,
|
|
316
338
|
)
|
|
317
339
|
# downcast to dtype and store for backward
|
|
318
340
|
ctx.save_for_backward(
|
|
@@ -321,13 +343,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
|
321
343
|
grad_bias.detach() if bias is not None else None,
|
|
322
344
|
)
|
|
323
345
|
ctx.return_z_loss = return_z_loss
|
|
324
|
-
|
|
346
|
+
ctx.return_token_accuracy = return_token_accuracy
|
|
347
|
+
return loss, z_loss, token_accuracy
|
|
325
348
|
|
|
326
349
|
@staticmethod
|
|
327
350
|
@amp_custom_bwd
|
|
328
|
-
def backward(ctx, grad_output, grad_output2):
|
|
351
|
+
def backward(ctx, grad_output, grad_output2, grad_output3):
|
|
329
352
|
if ctx.return_z_loss:
|
|
330
353
|
del grad_output2 # z_loss is only for logging
|
|
354
|
+
if ctx.return_token_accuracy:
|
|
355
|
+
del grad_output3 # token_accuracy is only for metrics
|
|
331
356
|
(grad_input, grad_weight, grad_bias) = ctx.saved_tensors
|
|
332
357
|
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
|
|
333
358
|
grad_output, grad_input, grad_weight, grad_bias
|
|
@@ -346,4 +371,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
|
346
371
|
None,
|
|
347
372
|
None,
|
|
348
373
|
None, # use_token_scaling
|
|
374
|
+
None, # return_token_accuracy
|
|
349
375
|
)
|
liger_kernel/ops/grpo_loss.py
CHANGED
|
@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
|
|
|
128
128
|
per_token_loss1 = coef_1 * advantage
|
|
129
129
|
per_token_loss2 = coef_2 * advantage
|
|
130
130
|
per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
|
|
131
|
-
|
|
131
|
+
is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
|
|
132
|
+
is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
|
|
133
|
+
is_clipped = is_low_clipped | is_high_clipped
|
|
132
134
|
|
|
133
135
|
if BETA != 0.0:
|
|
134
136
|
REF_LOGP += off_b * L + off_l
|
liger_kernel/ops/layer_norm.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import math
|
|
1
2
|
import operator
|
|
2
3
|
|
|
3
4
|
import torch
|
|
@@ -85,68 +86,87 @@ def _layer_norm_forward_kernel(
|
|
|
85
86
|
@triton.jit
def _layer_norm_backward_kernel(
    X_ptr,  # pointer to input, shape (n_rows, n_cols)
    stride_x,  # stride of each row in input
    W_ptr,  # pointer to weights, shape (n_cols,)
    Mean_ptr,  # pointer to mean, shape (n_rows,)
    stride_mean,  # stride of each row in mean
    RSTD_ptr,  # pointer to rstd, shape (n_rows,)
    stride_rstd,  # stride of each row in rstd
    DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
    stride_dx,  # stride of each row in input grad
    DW_ptr,  # pointer to partial weight grads; each program writes one row of n_cols values
    stride_dw,  # stride of each row in the partial weight-grad buffer
    DB_ptr,  # pointer to partial bias grads; each program writes one row of n_cols values
    stride_db,  # stride of each row in the partial bias-grad buffer
    DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
    stride_dy,  # stride of each row in output grad
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    LayerNorm backward. Each program handles a contiguous range of up to
    `rows_per_program` rows: it writes dX for every row in its range and
    accumulates fp32 partial sums of dW and dB, which it stores to row
    `program_id(0)` of the DW / DB buffers. Producing the final dW / dB from
    those partial rows is left to the caller.

    References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    """
    row_block_id = tl.program_id(0).to(tl.int64)
    row_start = row_block_id * rows_per_program
    # Clamp so the last program does not run past the final row
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols

    # fp32 accumulators for this program's share of the weight / bias gradients
    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Pre-load weights once (same optimization as forward pass)
    w = tl.load(W_ptr + cols, mask=mask, other=0.0)
    w_f32 = w.to(tl.float32)

    # Calculate pointers for the first row assigned to this program.
    # Mean / RSTD are offset with their own strides so that the starting
    # position agrees with the per-row `+= stride_*` advance in the loop below.
    row_X_ptr = X_ptr + row_start * stride_x
    row_DX_ptr = DX_ptr + row_start * stride_dx
    row_DY_ptr = DY_ptr + row_start * stride_dy
    row_Mean_ptr = Mean_ptr + row_start * stride_mean
    row_RSTD_ptr = RSTD_ptr + row_start * stride_rstd

    for _ in range(row_start, row_end):
        # Load data for this row
        x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
        dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
        mean = tl.load(row_Mean_ptr)
        rstd = tl.load(row_RSTD_ptr)

        # Convert to fp32 for numerical stability
        x_f32 = x.to(tl.float32)
        dy_f32 = dy.to(tl.float32)
        mean_f32 = mean.to(tl.float32)
        rstd_f32 = rstd.to(tl.float32)

        # Compute backward pass for this row
        x_hat = (x_f32 - mean_f32) * rstd_f32
        wdy = w_f32 * dy_f32
        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
        c2 = tl.sum(wdy, axis=0) / n_cols
        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32

        # Store input gradient
        tl.store(row_DX_ptr + cols, dx, mask=mask)

        # Accumulate weight and bias gradients for this program's assigned rows
        dw = dy_f32 * x_hat
        db = dy_f32
        dW_row += dw
        db_row += db

        # Advance every pointer to the next row
        row_X_ptr += stride_x
        row_DX_ptr += stride_dx
        row_DY_ptr += stride_dy
        row_Mean_ptr += stride_mean
        row_RSTD_ptr += stride_rstd

    # One partial-sum row per program
    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
|
|
150
170
|
|
|
151
171
|
|
|
152
172
|
def layer_norm_forward(X, W, B, eps):
|
|
@@ -228,31 +248,25 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
228
248
|
dY = dY.view(-1, dim)
|
|
229
249
|
n_rows, n_cols = dY.shape
|
|
230
250
|
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
251
|
+
sm_count = 1
|
|
252
|
+
if X.device.type == "cuda":
|
|
253
|
+
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
|
254
|
+
elif X.device.type == "xpu":
|
|
255
|
+
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
|
256
|
+
|
|
257
|
+
# fp32 for numerical stability especially.
|
|
258
|
+
_DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
259
|
+
_DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
237
260
|
|
|
238
261
|
# Calculate optimal block size and warp configuration
|
|
239
262
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
240
263
|
if n_cols > BLOCK_SIZE:
|
|
241
264
|
raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
|
|
265
|
+
rows_per_program = math.ceil(n_rows / sm_count)
|
|
266
|
+
grid = (sm_count,)
|
|
242
267
|
|
|
243
|
-
#
|
|
244
|
-
|
|
245
|
-
tl.float32
|
|
246
|
-
if X.dtype == torch.float32
|
|
247
|
-
else tl.bfloat16
|
|
248
|
-
if X.dtype == torch.bfloat16
|
|
249
|
-
else tl.float16
|
|
250
|
-
if X.dtype == torch.float16
|
|
251
|
-
else tl.float32 # fallback
|
|
252
|
-
)
|
|
253
|
-
|
|
254
|
-
# Use float32 for atomic operations if bfloat16 is not supported
|
|
255
|
-
atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
|
|
268
|
+
# Allocate gradient tensors
|
|
269
|
+
DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
256
270
|
|
|
257
271
|
kernel_args = {"num_warps": num_warps}
|
|
258
272
|
# XPU-specific optimization
|
|
@@ -260,28 +274,33 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
260
274
|
kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
|
|
261
275
|
|
|
262
276
|
# Launch one program per SM/EU (grid = (sm_count,)); each program processes up to rows_per_program rows
|
|
263
|
-
grid = (n_rows,)
|
|
264
277
|
_layer_norm_backward_kernel[grid](
|
|
265
278
|
X,
|
|
279
|
+
X.stride(0),
|
|
266
280
|
W,
|
|
267
281
|
Mean,
|
|
282
|
+
Mean.stride(0),
|
|
268
283
|
RSTD,
|
|
284
|
+
RSTD.stride(0),
|
|
269
285
|
DX,
|
|
270
|
-
DW,
|
|
271
|
-
DB,
|
|
272
|
-
dY,
|
|
273
|
-
X.stride(0),
|
|
274
286
|
DX.stride(0),
|
|
287
|
+
_DW,
|
|
288
|
+
_DW.stride(0),
|
|
289
|
+
_DB,
|
|
290
|
+
_DB.stride(0),
|
|
291
|
+
dY,
|
|
275
292
|
dY.stride(0),
|
|
293
|
+
n_rows,
|
|
276
294
|
n_cols,
|
|
295
|
+
rows_per_program=rows_per_program,
|
|
277
296
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
278
|
-
dtype=triton_dtype,
|
|
279
|
-
atomic_dtype=atomic_dtype,
|
|
280
297
|
**kernel_args,
|
|
281
298
|
)
|
|
282
299
|
|
|
283
300
|
DX = DX.view(*shape)
|
|
284
|
-
|
|
301
|
+
DW = _DW.sum(dim=0).to(W.dtype)
|
|
302
|
+
DB = _DB.sum(dim=0).to(B.dtype)
|
|
303
|
+
return DX, DW, DB
|
|
285
304
|
|
|
286
305
|
|
|
287
306
|
class LigerLayerNormFunction(torch.autograd.Function):
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
from typing import Callable
|
|
4
|
+
from typing import List
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LigerTiledMLPFunction(torch.autograd.Function):
    """
    Tiled MLP autograd function, based on DeepSpeed's TiledMLP:
    https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838

    The input is split into shards along the sequence dimension and the MLP is
    evaluated one shard at a time, which reduces the memory needed to compute
    the MLP for very long sequences.

    No activations are kept from `forward`; `backward` re-runs `fn` on each
    shard. The forward computation therefore happens twice per iteration, and
    three times when activation checkpointing is also in use.

    Args:
        fn: the function to call on sharded inputs (e.g., mlp.forward); it is
            invoked as `fn(mlp_module, x_shard)`
        mlp_module: the MLP nn.Module object
        x: the input to MLP.forward (hidden_states)
        shards: how many shards to use
        compute_params: a list of weights engaged in the compute; accepted as
            part of the signature but not referenced by this implementation

    Returns:
        the computed hidden_states
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        fn: Callable,
        mlp_module: torch.nn.Module,
        x: torch.Tensor,
        shards: int,
        compute_params: Optional[List[torch.nn.Parameter]] = None,
    ) -> torch.Tensor:
        # Only the input is saved; per-shard outputs are recomputed in backward.
        ctx.save_for_backward(x)
        ctx.fn = fn
        ctx.mlp_module = mlp_module
        ctx.shards = shards

        # x is [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts),
        # so dim=-2 is the sequence dimension in both layouts.
        pieces = torch.chunk(x, chunks=shards, dim=-2)
        with torch.no_grad():
            partial_outputs = [fn(mlp_module, piece) for piece in pieces]

        return torch.cat(partial_outputs, dim=-2)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, *grads) -> tuple:
        (x,) = ctx.saved_tensors
        fn, mlp_module, shards = ctx.fn, ctx.mlp_module, ctx.shards

        needs_input_grad = x.requires_grad
        # detach() clears requires_grad, so put the original flag back
        x = x.detach().requires_grad_(needs_input_grad)

        # x is [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        hidden_size = x.shape[-1]
        original_shape = x.shape

        # Collapse bs and seqlen into one row dimension to avoid stride issues
        # when narrowing into seqlen with bs > 1
        flat_x = x.view(-1, hidden_size)
        flat_grad_out = grads[0].view(-1, hidden_size)
        flat_grad_in = torch.zeros_like(flat_x)

        # Running row offset. torch.chunk makes every shard the same size except
        # possibly the last, so this equals `i * first_shard_rows`.
        row_offset = 0
        for piece in torch.chunk(flat_x, chunks=shards, dim=0):
            piece.requires_grad_(needs_input_grad)

            # The last shard is shorter when the row count is not divisible by shards
            piece_rows = piece.shape[0]

            # Point this shard's .grad at its slice of the shared buffer; the
            # returned input gradient relies on backward() filling that slice.
            piece.grad = flat_grad_in.narrow(0, row_offset, piece_rows).view_as(piece)
            grad_out_piece = flat_grad_out.narrow(0, row_offset, piece_rows).view_as(piece)

            # Recompute this shard's forward with grad enabled, then backprop
            # the matching slice of the incoming gradient through it.
            with torch.enable_grad():
                piece_out = fn(mlp_module, piece)
            torch.autograd.backward(piece_out, grad_out_piece)

            row_offset += piece_rows

        # One gradient slot per forward argument after ctx; only x receives one.
        return (None, None, flat_grad_in.view(original_shape), None, None)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def apply_tiled_mlp(
    fn: Callable,
    mlp_module: torch.nn.Module,
    x: torch.Tensor,
    num_shards: Optional[int] = None,
    compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
    """
    Run an MLP over `x` shard by shard via `LigerTiledMLPFunction`.

    Args:
        fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
        mlp_module: the MLP nn.Module object
        x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
        num_shards: number of shards to use. When None it defaults to
            ceil(seqlen / hidden_size). The value is clamped to at least 1.
        compute_params: list of parameters for DeepSpeed ZeRO optimization;
            forwarded unchanged to `LigerTiledMLPFunction`

    Returns:
        output tensor with the same shape as input
    """
    shard_count = num_shards
    if shard_count is None:
        # The last two dims are (seqlen, hidden_size) for both supported layouts
        hidden_size = x.shape[-1]
        seqlen = x.shape[-2]
        shard_count = math.ceil(seqlen / hidden_size)

    # Short sequences (seqlen < hidden_size) and non-positive requests fall back to one shard
    shard_count = max(1, shard_count)

    return LigerTiledMLPFunction.apply(fn, mlp_module, x, shard_count, compute_params)
|
|
@@ -24,6 +24,8 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F4
|
|
|
24
24
|
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
|
|
25
25
|
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401
|
|
26
26
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
|
|
27
|
+
from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP # noqa: F401
|
|
28
|
+
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP # noqa: F401
|
|
27
29
|
from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
|
|
28
30
|
|
|
29
31
|
# Static-only imports for IDEs and type checkers
|
|
@@ -40,6 +42,8 @@ if TYPE_CHECKING:
|
|
|
40
42
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
|
|
41
43
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
|
|
42
44
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
|
|
45
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense # noqa: F401
|
|
46
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe # noqa: F401
|
|
43
47
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
|
|
44
48
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
|
|
45
49
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
|
|
@@ -48,6 +52,7 @@ if TYPE_CHECKING:
|
|
|
48
52
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
|
|
49
53
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
|
|
50
54
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
|
|
55
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3 # noqa: F401
|
|
51
56
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
|
|
52
57
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
|
|
53
58
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
|
|
@@ -56,6 +61,8 @@ if TYPE_CHECKING:
|
|
|
56
61
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
|
|
57
62
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
|
|
58
63
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
|
|
64
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl # noqa: F401
|
|
65
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe # noqa: F401
|
|
59
66
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
|
|
60
67
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401
|
|
61
68
|
|
|
@@ -112,6 +119,7 @@ def __getattr__(name: str):
|
|
|
112
119
|
"apply_liger_kernel_to_mixtral",
|
|
113
120
|
"apply_liger_kernel_to_mllama",
|
|
114
121
|
"apply_liger_kernel_to_olmo2",
|
|
122
|
+
"apply_liger_kernel_to_olmo3",
|
|
115
123
|
"apply_liger_kernel_to_paligemma",
|
|
116
124
|
"apply_liger_kernel_to_phi3",
|
|
117
125
|
"apply_liger_kernel_to_qwen2",
|
|
@@ -120,8 +128,12 @@ def __getattr__(name: str):
|
|
|
120
128
|
"apply_liger_kernel_to_qwen3",
|
|
121
129
|
"apply_liger_kernel_to_qwen3_moe",
|
|
122
130
|
"apply_liger_kernel_to_qwen3_next",
|
|
131
|
+
"apply_liger_kernel_to_qwen3_vl",
|
|
132
|
+
"apply_liger_kernel_to_qwen3_vl_moe",
|
|
123
133
|
"apply_liger_kernel_to_smollm3",
|
|
124
134
|
"apply_liger_kernel_to_smolvlm",
|
|
135
|
+
"apply_liger_kernel_to_hunyuan_v1_dense",
|
|
136
|
+
"apply_liger_kernel_to_hunyuan_v1_moe",
|
|
125
137
|
}
|
|
126
138
|
|
|
127
139
|
if name in monkey_patch_symbols:
|
|
@@ -151,6 +163,8 @@ __all__ = [
|
|
|
151
163
|
"LigerPhi3SwiGLUMLP",
|
|
152
164
|
"LigerQwen3MoeSwiGLUMLP",
|
|
153
165
|
"LigerSwiGLUMLP",
|
|
166
|
+
"LigerTiledGEGLUMLP",
|
|
167
|
+
"LigerTiledSwiGLUMLP",
|
|
154
168
|
"LigerTVDLoss",
|
|
155
169
|
"LigerKLDIVLoss",
|
|
156
170
|
"LigerMultiTokenAttention",
|
|
@@ -182,6 +196,7 @@ if _TRANSFORMERS_AVAILABLE:
|
|
|
182
196
|
"apply_liger_kernel_to_mixtral",
|
|
183
197
|
"apply_liger_kernel_to_mllama",
|
|
184
198
|
"apply_liger_kernel_to_olmo2",
|
|
199
|
+
"apply_liger_kernel_to_olmo3",
|
|
185
200
|
"apply_liger_kernel_to_paligemma",
|
|
186
201
|
"apply_liger_kernel_to_phi3",
|
|
187
202
|
"apply_liger_kernel_to_qwen2",
|
|
@@ -190,7 +205,11 @@ if _TRANSFORMERS_AVAILABLE:
|
|
|
190
205
|
"apply_liger_kernel_to_qwen3",
|
|
191
206
|
"apply_liger_kernel_to_qwen3_moe",
|
|
192
207
|
"apply_liger_kernel_to_qwen3_next",
|
|
208
|
+
"apply_liger_kernel_to_qwen3_vl",
|
|
209
|
+
"apply_liger_kernel_to_qwen3_vl_moe",
|
|
193
210
|
"apply_liger_kernel_to_smollm3",
|
|
194
211
|
"apply_liger_kernel_to_smolvlm",
|
|
212
|
+
"apply_liger_kernel_to_hunyuan_v1_dense",
|
|
213
|
+
"apply_liger_kernel_to_hunyuan_v1_moe",
|
|
195
214
|
]
|
|
196
215
|
)
|
|
@@ -3,6 +3,7 @@ from typing import Optional
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
5
|
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
|
|
6
|
+
from liger_kernel.transformers.functional import CrossEntropyOutput
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class LigerCrossEntropyLoss(torch.nn.Module):
|
|
@@ -15,6 +16,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
|
|
|
15
16
|
reduction: str = "mean",
|
|
16
17
|
softcap: Optional[float] = None,
|
|
17
18
|
return_z_loss: bool = False,
|
|
19
|
+
return_token_accuracy: bool = False,
|
|
18
20
|
):
|
|
19
21
|
super().__init__()
|
|
20
22
|
assert (label_smoothing >= 0) and (label_smoothing <= 1), (
|
|
@@ -33,9 +35,10 @@ class LigerCrossEntropyLoss(torch.nn.Module):
|
|
|
33
35
|
self.reduction = reduction
|
|
34
36
|
self.softcap = softcap
|
|
35
37
|
self.return_z_loss = return_z_loss
|
|
38
|
+
self.return_token_accuracy = return_token_accuracy
|
|
36
39
|
|
|
37
40
|
def forward(self, _input: torch.Tensor, target: torch.Tensor):
|
|
38
|
-
loss, z_loss = LigerCrossEntropyFunction.apply(
|
|
41
|
+
loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
|
|
39
42
|
_input,
|
|
40
43
|
target,
|
|
41
44
|
self.weight,
|
|
@@ -45,7 +48,9 @@ class LigerCrossEntropyLoss(torch.nn.Module):
|
|
|
45
48
|
self.reduction,
|
|
46
49
|
self.softcap,
|
|
47
50
|
self.return_z_loss,
|
|
51
|
+
self.return_token_accuracy,
|
|
48
52
|
)
|
|
49
|
-
if not self.return_z_loss:
|
|
53
|
+
if not self.return_z_loss and not self.return_token_accuracy:
|
|
50
54
|
return loss
|
|
51
|
-
|
|
55
|
+
|
|
56
|
+
return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
|