liger-kernel 0.6.2__py3-none-any.whl → 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
- liger_kernel/chunked_loss/grpo_loss.py +38 -4
- liger_kernel/chunked_loss/jsd_loss.py +5 -2
- liger_kernel/ops/cross_entropy.py +59 -53
- liger_kernel/ops/fused_linear_cross_entropy.py +68 -10
- liger_kernel/ops/layer_norm.py +4 -6
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/transformers/__init__.py +17 -0
- liger_kernel/transformers/functional.py +7 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +5 -1
- liger_kernel/transformers/model/falcon_h1.py +108 -0
- liger_kernel/transformers/model/gemma.py +2 -1
- liger_kernel/transformers/model/gemma2.py +8 -2
- liger_kernel/transformers/model/gemma3.py +27 -2
- liger_kernel/transformers/model/glm4.py +2 -1
- liger_kernel/transformers/model/glm4v.py +3 -2
- liger_kernel/transformers/model/glm4v_moe.py +153 -0
- liger_kernel/transformers/model/internvl.py +150 -0
- liger_kernel/transformers/model/llama.py +2 -1
- liger_kernel/transformers/model/llama4.py +2 -1
- liger_kernel/transformers/model/llava.py +6 -2
- liger_kernel/transformers/model/loss_utils.py +1 -0
- liger_kernel/transformers/model/mistral.py +2 -1
- liger_kernel/transformers/model/mixtral.py +8 -2
- liger_kernel/transformers/model/mllama.py +2 -1
- liger_kernel/transformers/model/olmo2.py +2 -1
- liger_kernel/transformers/model/paligemma.py +19 -0
- liger_kernel/transformers/model/phi3.py +2 -1
- liger_kernel/transformers/model/qwen2.py +2 -1
- liger_kernel/transformers/model/qwen2_5_vl.py +7 -2
- liger_kernel/transformers/model/qwen2_vl.py +7 -2
- liger_kernel/transformers/model/qwen3.py +2 -1
- liger_kernel/transformers/model/qwen3_moe.py +8 -2
- liger_kernel/transformers/model/qwen3_next.py +134 -0
- liger_kernel/transformers/model/smollm3.py +2 -1
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +452 -3
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/METADATA +13 -10
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/RECORD +46 -39
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/WHEEL +0 -0
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
import operator
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
import triton.language as tl
|
|
6
|
+
|
|
7
|
+
from liger_kernel.ops.utils import calculate_settings
|
|
8
|
+
from liger_kernel.ops.utils import compare_version
|
|
9
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
|
|
11
|
+
if compare_version("triton", operator.ge, "3.0.0"):
|
|
12
|
+
try:
|
|
13
|
+
from triton.language.extra.libdevice import rsqrt
|
|
14
|
+
except ModuleNotFoundError:
|
|
15
|
+
from triton.language.extra.cuda.libdevice import rsqrt
|
|
16
|
+
else:
|
|
17
|
+
from triton.language.math import rsqrt
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@triton.jit
def _poly_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,  # weight: [3] for [w0, w1, w2]
    B_ptr,  # bias: scalar
    RSTD_ptr,  # cache rstd for backward: shape (n_rows, 3)
    RSTD_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm forward kernel; each program instance handles one row.

    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    Reference:
    1. https://github.com/BryceZhuo/PolyCom/
    2. https://arxiv.org/pdf/2411.03884

    All arithmetic is carried out in float32 regardless of the input dtype. The
    x³ branch needs mean((x³)²), i.e. x⁶, which overflows float16 for |x| > ~6.3
    and keeps very little precision in bfloat16. The backward kernel also works
    in float32, so forward and backward now agree. The three rstd values are
    cached (float32) for the backward pass; the output row is converted back to
    Y's element type by tl.store.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Advance the base pointers to this program's row
    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    # Masked-off lanes load 0.0 so they add nothing to the row sums below
    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)

    # Load weights and bias, upcast to match X_row
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)
    b = tl.load(B_ptr).to(tl.float32)

    # Compute x³, x², x
    X_pow3 = X_row * X_row * X_row
    X_pow2 = X_row * X_row
    X_pow1 = X_row

    # norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
    mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
    rstd_3 = rsqrt(mean_square_3 + eps)
    norm_x3 = X_pow3 * rstd_3

    # norm(x²)
    mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
    rstd_2 = rsqrt(mean_square_2 + eps)
    norm_x2 = X_pow2 * rstd_2

    # norm(x)
    mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
    rstd_1 = rsqrt(mean_square_1 + eps)
    norm_x1 = X_pow1 * rstd_1

    # Cache rstd values for backward, in [rstd_3, rstd_2, rstd_1] order
    tl.store(RSTD_ptr + 0, rstd_3)
    tl.store(RSTD_ptr + 1, rstd_2)
    tl.store(RSTD_ptr + 2, rstd_1)

    # y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
    Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b

    # tl.store casts the float32 result to Y's element type
    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@triton.jit
def _poly_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,  # shape: (n_programs, 3)
    dW_row_stride,
    dB_ptr,  # shape: (n_programs,)
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm backward kernel.

    Each program processes the contiguous row range
    [pid * rows_per_program, min((pid + 1) * rows_per_program, n_rows)).
    For every row it writes dX. For the whole range it writes one partial
    (dW0, dW1, dW2) triple into row `pid` of dW, and one partial dB into dB[pid].
    The host reduces these partials.

    With g = incoming gradient row, r_p = cached rstd_p
    (= 1/sqrt(mean(x^(2p)) + eps)), S_p = Σ_j g_j·x_j^p, d = n_cols and
    p ∈ {3, 2, 1} paired with weights (w0, w1, w2):
        ∂L/∂x_i = Σ_p w_p * [p·x_i^(p-1)·r_p·g_i - (p/d)·x_i^(2p-1)·r_p³·S_p]
        ∂L/∂w_p = Σ_rows r_p·S_p
        ∂L/∂b   = Σ_rows Σ_j g_j
    """
    pid = tl.program_id(0).to(tl.int64)
    first_row = pid * rows_per_program
    last_row = min((pid + 1) * rows_per_program, n_rows)
    cols = tl.arange(0, BLOCK_SIZE)
    col_mask = cols < n_cols

    # The three polynomial weights, upcast once for the whole row range
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)

    # Per-program partial sums for the parameter gradients
    dw0_sum = 0.0
    dw1_sum = 0.0
    dw2_sum = 0.0
    db_sum = 0.0

    # Jump every row pointer to the first row this program owns
    dY_ptr += first_row * dY_row_stride
    dX_ptr += first_row * dX_row_stride
    X_ptr += first_row * X_row_stride
    RSTD_ptr += first_row * RSTD_row_stride

    for _ in range(first_row, last_row):
        g = tl.load(dY_ptr + cols, mask=col_mask, other=0.0).to(tl.float32)
        x = tl.load(X_ptr + cols, mask=col_mask, other=0.0).to(tl.float32)

        # rstd cache layout written by the forward kernel: [rstd_3, rstd_2, rstd_1]
        r3 = tl.load(RSTD_ptr + 0).to(tl.float32)
        r2 = tl.load(RSTD_ptr + 1).to(tl.float32)
        r1 = tl.load(RSTD_ptr + 2).to(tl.float32)

        x_cube = x * x * x
        x_sq = x * x

        db_sum += tl.sum(g, axis=0)

        # p = 3 term (weight w0)
        s3 = tl.sum(g * x_cube, axis=0)
        grad3 = w0 * (
            3.0 * x_sq * r3 * g
            - (3.0 / n_cols) * x * x * x * x * x * (r3 * r3 * r3) * s3
        )

        # p = 2 term (weight w1)
        s2 = tl.sum(g * x_sq, axis=0)
        grad2 = w1 * (
            2.0 * x * r2 * g
            - (2.0 / n_cols) * x * x * x * (r2 * r2 * r2) * s2
        )

        # p = 1 term (weight w2)
        s1 = tl.sum(g * x, axis=0)
        grad1 = w2 * (
            1.0 * r1 * g
            - (1.0 / n_cols) * x * (r1 * r1 * r1) * s1
        )

        # dW_p contribution of this row is r_p * S_p
        dw0_sum += r3 * s3
        dw1_sum += r2 * s2
        dw2_sum += r1 * s1

        tl.store(dX_ptr + cols, grad3 + grad2 + grad1, mask=col_mask)

        # Step to the next row
        dY_ptr += dY_row_stride
        dX_ptr += dX_row_stride
        X_ptr += X_row_stride
        RSTD_ptr += RSTD_row_stride

    # Publish this program's partial sums
    dw_row = dW_ptr + pid * dW_row_stride
    tl.store(dw_row + 0, dw0_sum)
    tl.store(dw_row + 1, dw1_sum)
    tl.store(dw_row + 2, dw2_sum)
    tl.store(dB_ptr + pid, db_sum)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def poly_norm_forward(X, W, B, eps=1e-6):
    """
    Run the PolyNorm forward kernel over every row of X.

    Args:
        X: input tensor of shape (*, H); it is viewed as (N, H), one kernel
            program per row.
        W: weight tensor of shape (3,) holding [w0, w1, w2].
        B: single-element bias tensor.
        eps: epsilon added under each square root.

    Returns:
        A 5-tuple (Y, X_2d, RSTD, BLOCK_SIZE, num_warps):
        - Y: output with X's original shape and dtype;
        - X_2d: the (N, H) view of X, kept for backward;
        - RSTD: float32 tensor of shape (N, 3) with the per-row rstd values;
        - BLOCK_SIZE, num_warps: launch settings, reused by backward.
    """
    orig_shape = X.shape
    hidden = orig_shape[-1]
    X = X.view(-1, hidden)
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    # Output buffer plus a per-row cache of the three rstd values
    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)

    assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
    assert B.numel() == 1, "Bias must be a scalar"

    # Intel XPU backend: request the large register file
    launch_extras = {"grf_mode": "large"} if X.device.type == "xpu" else {}

    _poly_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        B,
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **launch_extras,
    )

    return Y.view(*orig_shape), X, RSTD, BLOCK_SIZE, num_warps
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
    """
    Run the PolyNorm backward kernel and reduce its per-program partial sums.

    Args:
        dY: gradient of the output (any shape ending in H).
        X: the (N, H) input view saved by the forward pass.
        W: weight tensor of shape (3,).
        RSTD: (N, 3) float32 rstd cache written by the forward pass.
        BLOCK_SIZE: block size chosen by the forward pass.
        num_warps: warp count chosen by the forward pass.
        in_place: when exactly ``True``, dX is written into dY's storage
            (saves one allocation).

    Returns:
        (dX, dW, dB): dX has dY's shape; dW (shape (3,)) and dB (0-dim) are
        cast to W's dtype.
    """
    import math

    grad_shape = dY.shape
    dY = dY.view(-1, grad_shape[-1])
    n_rows, n_cols = dY.shape

    # One program per multiprocessor (CUDA) or execution unit (XPU); otherwise 1
    device_kind = X.device.type
    if device_kind == "cuda":
        n_programs = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif device_kind == "xpu":
        n_programs = torch.xpu.get_device_properties(X.device).gpu_eu_count
    else:
        n_programs = 1

    dX = dY if in_place is True else torch.zeros_like(dY)

    # Each program writes one row of partial parameter gradients
    partial_dW = torch.empty((n_programs, 3), dtype=torch.float32, device=W.device)
    partial_dB = torch.empty((n_programs,), dtype=torch.float32, device=W.device)

    rows_per_program = math.ceil(n_rows / n_programs)

    # Intel XPU backend: request the large register file
    launch_extras = {"grf_mode": "large"} if device_kind == "xpu" else {}

    _poly_norm_backward_kernel[(n_programs,)](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        X,
        X.stride(0),
        W,
        RSTD,
        RSTD.stride(0),
        partial_dW,
        partial_dW.stride(0),
        partial_dB,
        n_rows,
        n_cols,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **launch_extras,
    )

    return (
        dX.view(*grad_shape),
        partial_dW.sum(dim=0).to(W.dtype),
        partial_dB.sum().to(W.dtype),
    )
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
class LigerPolyNormFunction(torch.autograd.Function):
    """
    PolyNorm autograd Function backed by Triton forward/backward kernels.

    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    Backward uses the closed-form gradient:
        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
    with D_p = 1/rstd_p = sqrt(mean(x^(2p)) + ε), S_p = Σ_j grad_j·x_j^p and d the
    hidden size.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, B, eps=1e-6, in_place=True):
        """
        Args:
            X: input tensor of shape (batch, seq, H) or (batch*seq, H)
            W: weight tensor of shape (3,) for [w0, w1, w2]
            B: single-element bias tensor (e.g. shape () or (1,))
            eps: epsilon for numerical stability
            in_place: whether to in-place modify grad_output in backward (saves memory)

        Returns:
            Y: output tensor of same shape as X
        """
        Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.in_place = in_place
        # poly_norm_backward always returns a 0-dim dB. Autograd rejects a 0-dim
        # gradient for a bias of shape (1,), so remember B's shape/dtype here.
        ctx.bias_shape = B.shape
        ctx.bias_dtype = B.dtype
        ctx.save_for_backward(X, W, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """
        Args:
            grad_output: gradient of output

        Returns:
            dX, dW, dB: gradients w.r.t. X, W, B (dB matches B's shape and dtype),
            followed by None for the non-tensor eps and in_place arguments.
        """
        X, W, RSTD = ctx.saved_tensors
        dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place)
        # dB holds exactly one element; give it B's shape and dtype
        dB = dB.to(ctx.bias_dtype).reshape(ctx.bias_shape)
        return dX, dW, dB, None, None
|
|
@@ -15,6 +15,7 @@ from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
|
|
|
15
15
|
from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401
|
|
16
16
|
from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401
|
|
17
17
|
from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401
|
|
18
|
+
from liger_kernel.transformers.poly_norm import LigerPolyNorm # noqa: F401
|
|
18
19
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
|
|
19
20
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
|
|
20
21
|
from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401
|
|
@@ -30,13 +31,16 @@ if TYPE_CHECKING:
|
|
|
30
31
|
from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
|
|
31
32
|
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
|
|
32
33
|
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
|
|
34
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1 # noqa: F401
|
|
33
35
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
|
|
34
36
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
|
|
35
37
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
|
|
36
38
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
|
|
37
39
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
|
|
38
40
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
|
|
41
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
|
|
39
42
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
|
|
43
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
|
|
40
44
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
|
|
41
45
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
|
|
42
46
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
|
|
@@ -51,7 +55,9 @@ if TYPE_CHECKING:
|
|
|
51
55
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
|
|
52
56
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
|
|
53
57
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
|
|
58
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
|
|
54
59
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
|
|
60
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401
|
|
55
61
|
|
|
56
62
|
|
|
57
63
|
# Check if 'transformers' is installed
|
|
@@ -89,13 +95,16 @@ def __getattr__(name: str):
|
|
|
89
95
|
monkey_patch_symbols = {
|
|
90
96
|
"_apply_liger_kernel",
|
|
91
97
|
"_apply_liger_kernel_to_instance",
|
|
98
|
+
"apply_liger_kernel_to_falcon_h1",
|
|
92
99
|
"apply_liger_kernel_to_gemma",
|
|
93
100
|
"apply_liger_kernel_to_gemma2",
|
|
94
101
|
"apply_liger_kernel_to_gemma3",
|
|
95
102
|
"apply_liger_kernel_to_gemma3_text",
|
|
96
103
|
"apply_liger_kernel_to_glm4",
|
|
97
104
|
"apply_liger_kernel_to_glm4v",
|
|
105
|
+
"apply_liger_kernel_to_glm4v_moe",
|
|
98
106
|
"apply_liger_kernel_to_granite",
|
|
107
|
+
"apply_liger_kernel_to_internvl",
|
|
99
108
|
"apply_liger_kernel_to_llama",
|
|
100
109
|
"apply_liger_kernel_to_llava",
|
|
101
110
|
"apply_liger_kernel_to_llama4",
|
|
@@ -110,7 +119,9 @@ def __getattr__(name: str):
|
|
|
110
119
|
"apply_liger_kernel_to_qwen2_vl",
|
|
111
120
|
"apply_liger_kernel_to_qwen3",
|
|
112
121
|
"apply_liger_kernel_to_qwen3_moe",
|
|
122
|
+
"apply_liger_kernel_to_qwen3_next",
|
|
113
123
|
"apply_liger_kernel_to_smollm3",
|
|
124
|
+
"apply_liger_kernel_to_smolvlm",
|
|
114
125
|
}
|
|
115
126
|
|
|
116
127
|
if name in monkey_patch_symbols:
|
|
@@ -131,6 +142,7 @@ __all__ = [
|
|
|
131
142
|
"LigerJSD",
|
|
132
143
|
"LigerLayerNorm",
|
|
133
144
|
"LigerFusedAddRMSNorm",
|
|
145
|
+
"LigerPolyNorm",
|
|
134
146
|
"LigerRMSNorm",
|
|
135
147
|
"liger_rotary_pos_emb",
|
|
136
148
|
"liger_llama4_text_rotary_pos_emb",
|
|
@@ -153,13 +165,16 @@ if _TRANSFORMERS_AVAILABLE:
|
|
|
153
165
|
"AutoLigerKernelForCausalLM",
|
|
154
166
|
"_apply_liger_kernel",
|
|
155
167
|
"_apply_liger_kernel_to_instance",
|
|
168
|
+
"apply_liger_kernel_to_falcon_h1",
|
|
156
169
|
"apply_liger_kernel_to_gemma",
|
|
157
170
|
"apply_liger_kernel_to_gemma2",
|
|
158
171
|
"apply_liger_kernel_to_gemma3",
|
|
159
172
|
"apply_liger_kernel_to_gemma3_text",
|
|
160
173
|
"apply_liger_kernel_to_glm4",
|
|
161
174
|
"apply_liger_kernel_to_glm4v",
|
|
175
|
+
"apply_liger_kernel_to_glm4v_moe",
|
|
162
176
|
"apply_liger_kernel_to_granite",
|
|
177
|
+
"apply_liger_kernel_to_internvl",
|
|
163
178
|
"apply_liger_kernel_to_llama",
|
|
164
179
|
"apply_liger_kernel_to_llava",
|
|
165
180
|
"apply_liger_kernel_to_llama4",
|
|
@@ -174,6 +189,8 @@ if _TRANSFORMERS_AVAILABLE:
|
|
|
174
189
|
"apply_liger_kernel_to_qwen2_vl",
|
|
175
190
|
"apply_liger_kernel_to_qwen3",
|
|
176
191
|
"apply_liger_kernel_to_qwen3_moe",
|
|
192
|
+
"apply_liger_kernel_to_qwen3_next",
|
|
177
193
|
"apply_liger_kernel_to_smollm3",
|
|
194
|
+
"apply_liger_kernel_to_smolvlm",
|
|
178
195
|
]
|
|
179
196
|
)
|
|
@@ -12,6 +12,7 @@ from liger_kernel.ops.jsd import LigerJSDFunction
|
|
|
12
12
|
from liger_kernel.ops.kl_div import LigerKLDivLossFunction
|
|
13
13
|
from liger_kernel.ops.layer_norm import LigerLayerNormFunction
|
|
14
14
|
from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
|
|
15
|
+
from liger_kernel.ops.poly_norm import LigerPolyNormFunction
|
|
15
16
|
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
|
|
16
17
|
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
|
|
17
18
|
from liger_kernel.ops.rope import LigerRopeFunction
|
|
@@ -65,6 +66,7 @@ def liger_fused_linear_cross_entropy(
|
|
|
65
66
|
softcap: Optional[float] = None,
|
|
66
67
|
return_z_loss: bool = False,
|
|
67
68
|
accum_dtype=None,
|
|
69
|
+
use_token_scaling: bool = False,
|
|
68
70
|
):
|
|
69
71
|
loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
|
|
70
72
|
input,
|
|
@@ -79,6 +81,7 @@ def liger_fused_linear_cross_entropy(
|
|
|
79
81
|
softcap,
|
|
80
82
|
return_z_loss,
|
|
81
83
|
accum_dtype,
|
|
84
|
+
use_token_scaling,
|
|
82
85
|
)
|
|
83
86
|
if not return_z_loss:
|
|
84
87
|
return loss
|
|
@@ -256,6 +259,10 @@ def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama",
|
|
|
256
259
|
return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
|
|
257
260
|
|
|
258
261
|
|
|
262
|
+
def liger_poly_norm(X, W, B, eps=1e-6, in_place=True):
    """Functional PolyNorm: apply ``LigerPolyNormFunction`` to (X, W, B).

    Arguments are forwarded positionally as (X, W, B, eps, in_place).
    """
    poly_norm_fn = LigerPolyNormFunction.apply
    return poly_norm_fn(X, W, B, eps, in_place)
|
|
264
|
+
|
|
265
|
+
|
|
259
266
|
def liger_fused_add_rms_norm(X, R, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
|
|
260
267
|
return LigerFusedAddRMSNormFunction.apply(X, R, W, eps, offset, casting_mode, in_place)
|
|
261
268
|
|
|
@@ -16,6 +16,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
|
16
16
|
softcap: Optional[float] = None,
|
|
17
17
|
return_z_loss: bool = False,
|
|
18
18
|
accum_dtype: Optional[torch.dtype] = None,
|
|
19
|
+
use_token_scaling: bool = False,
|
|
19
20
|
):
|
|
20
21
|
super().__init__()
|
|
21
22
|
assert (label_smoothing >= 0) and (label_smoothing <= 1), (
|
|
@@ -24,7 +25,8 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
|
24
25
|
assert reduction in {
|
|
25
26
|
"mean",
|
|
26
27
|
"sum",
|
|
27
|
-
|
|
28
|
+
"none",
|
|
29
|
+
}, f"reduction must be 'mean' or 'sum' or 'none'. Got: {reduction}"
|
|
28
30
|
assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
|
|
29
31
|
self.ce_weight = ce_weight
|
|
30
32
|
self.ignore_index = ignore_index
|
|
@@ -34,6 +36,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
|
34
36
|
self.softcap = softcap
|
|
35
37
|
self.return_z_loss = return_z_loss
|
|
36
38
|
self.accum_dtype = accum_dtype
|
|
39
|
+
self.use_token_scaling = use_token_scaling
|
|
37
40
|
|
|
38
41
|
def forward(self, lin_weight, _input, target, bias=None):
|
|
39
42
|
loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
|
|
@@ -49,6 +52,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
|
49
52
|
self.softcap,
|
|
50
53
|
self.return_z_loss,
|
|
51
54
|
self.accum_dtype,
|
|
55
|
+
self.use_token_scaling,
|
|
52
56
|
)
|
|
53
57
|
if not self.return_z_loss:
|
|
54
58
|
return loss
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from transformers.models.falcon_h1.modeling_falcon_h1 import FalconHybridMambaAttentionDynamicCache
|
|
11
|
+
|
|
12
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional["FalconHybridMambaAttentionDynamicCache"] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[tuple, CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Pre-shifted targets may be passed instead as the `shift_labels` keyword argument. Either `labels` or
    `shift_labels` enables the loss. When `skip_logits` is true (by default: training mode with targets present),
    the fused linear-cross-entropy loss is computed and `logits` is returned as None.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FalconH1ForCausalLM

    >>> model = FalconH1ForCausalLM.from_pretrained("...")
    >>> tokenizer = AutoTokenizer.from_pretrained("...")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )

    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    # shift_labels is removed from kwargs so it is not passed twice to the loss functions below
    shift_labels = kwargs.pop("shift_labels", None)
    has_targets = labels is not None or shift_labels is not None
    logits = None
    loss = None

    if skip_logits and not has_targets:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and has_targets

    if skip_logits:
        loss = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
    else:
        logits = self.lm_head(kept_hidden_states)
        if has_targets:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
|
@@ -228,10 +228,11 @@ def lce_forward(
|
|
|
228
228
|
)
|
|
229
229
|
else:
|
|
230
230
|
logits = self.lm_head(kept_hidden_states)
|
|
231
|
-
if labels is not None:
|
|
231
|
+
if labels is not None or shift_labels is not None:
|
|
232
232
|
loss = self.loss_function(
|
|
233
233
|
logits=logits,
|
|
234
234
|
labels=labels,
|
|
235
|
+
shift_labels=shift_labels,
|
|
235
236
|
vocab_size=self.config.vocab_size,
|
|
236
237
|
**kwargs,
|
|
237
238
|
)
|
|
@@ -252,8 +252,14 @@ def lce_forward(
|
|
|
252
252
|
logits = logits * self.config.final_logit_softcapping
|
|
253
253
|
|
|
254
254
|
loss = None
|
|
255
|
-
if labels is not None:
|
|
256
|
-
loss = self.loss_function(
|
|
255
|
+
if labels is not None or shift_labels is not None:
|
|
256
|
+
loss = self.loss_function(
|
|
257
|
+
logits=logits,
|
|
258
|
+
labels=labels,
|
|
259
|
+
shift_labels=shift_labels,
|
|
260
|
+
vocab_size=self.vocab_size,
|
|
261
|
+
**kwargs,
|
|
262
|
+
)
|
|
257
263
|
|
|
258
264
|
if not return_dict:
|
|
259
265
|
output = (logits,) + outputs[1:]
|