liger-kernel 0.5.10__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/ops/dyt.py +0 -2
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/layer_norm.py +126 -89
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +267 -56
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +62 -50
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +8 -0
- liger_kernel/transformers/functional.py +67 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/model/gemma.py +25 -8
- liger_kernel/transformers/model/gemma2.py +27 -8
- liger_kernel/transformers/model/gemma3.py +63 -99
- liger_kernel/transformers/model/glm4.py +16 -7
- liger_kernel/transformers/model/llama.py +25 -7
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -124
- liger_kernel/transformers/model/mistral.py +13 -8
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +16 -7
- liger_kernel/transformers/model/olmo2.py +16 -7
- liger_kernel/transformers/model/paligemma.py +8 -1
- liger_kernel/transformers/model/phi3.py +25 -8
- liger_kernel/transformers/model/qwen2.py +24 -7
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
- liger_kernel/transformers/model/qwen2_vl.py +38 -100
- liger_kernel/transformers/model/qwen3.py +11 -3
- liger_kernel/transformers/model/qwen3_moe.py +10 -6
- liger_kernel/transformers/model/smollm3.py +189 -0
- liger_kernel/transformers/monkey_patch.py +389 -82
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/METADATA +18 -14
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/RECORD +47 -37
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py
CHANGED

@@ -40,7 +40,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
     tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
-    c_row = geglu_a * b_row
+    c_row = geglu_a.cast(b_row.dtype) * b_row
     tl.store(c + col_offsets, c_row, mask=mask)
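The only functional change in this kernel is that the fp32 tanh-GELU activation is cast to the dtype of `b_row` before the elementwise product, so the multiply and the following `tl.store` stay in the input dtype instead of relying on implicit promotion. A minimal PyTorch-level sketch of the same pattern (illustrative only; the tensor names here are hypothetical and not taken from the kernel):

    import torch

    a = torch.randn(8, dtype=torch.float32)    # activation math kept in fp32 for accuracy
    b = torch.randn(8, dtype=torch.bfloat16)   # gate input kept in the model dtype

    sqrt_2_over_pi = 0.7978845608028654
    geglu_a = 0.5 * a * (1 + torch.tanh(sqrt_2_over_pi * (a + 0.044715 * a**3)))
    c = geglu_a.to(b.dtype) * b                # cast first, mirroring geglu_a.cast(b_row.dtype)
    assert c.dtype == torch.bfloat16           # without the cast, c would be promoted to fp32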
liger_kernel/ops/layer_norm.py
CHANGED

@@ -1,4 +1,3 @@
-import math
 import operator
 
 import torch
@@ -43,30 +42,45 @@ def _layer_norm_forward_kernel(
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    ... (old lines 50-61: previous implementation, not shown in this view)
+    # Pre-load weights and bias in fp32 to avoid repeated conversions
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
+    W_f32 = W_row.to(tl.float32)
+    B_f32 = B_row.to(tl.float32)
+
+    # Calculate pointers for this row
+    row_X_ptr = X_ptr + row_idx * X_row_stride
+    row_Y_ptr = Y_ptr + row_idx * Y_row_stride
+    row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
+    row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+    # Load input data and convert to fp32 for numerical stability
+    X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
+    X_f32 = X_row.to(tl.float32)
+
+    # Compute statistics in fp32 for numerical stability
+    n_cols_f32 = n_cols.to(tl.float32)
+    mean = tl.sum(X_f32, axis=0) / n_cols_f32
+    X_centered = X_f32 - mean
+    # Apply mask to variance calculation to exclude contributions from masked elements
+    X_centered_masked = tl.where(mask, X_centered, 0.0)
+    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols_f32
     rstd = rsqrt(var + eps)
 
-
-    tl.store(
+    # Store statistics (convert back to original dtype only once)
+    tl.store(row_Mean_ptr, mean.to(X_row.dtype))
+    tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
 
-
+    # Fused normalization and affine transformation
+    # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
+    Y_f32 = X_centered * rstd * W_f32 + B_f32
 
-
+    # Store output (single conversion back to original dtype)
+    tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
 
 
 @triton.jit
@@ -81,73 +95,87 @@ def _layer_norm_backward_kernel(
     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
     stride_x,  # stride of each row in input
     stride_dx,  # stride of each row in input grad
-    stride_dw,  # stride of each row in weights grad
-    stride_db,  # stride of each row in bias grad
     stride_dy,  # stride of each row in output grad
-    n_rows,
     n_cols,
-    rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     dtype: tl.constexpr,
+    atomic_dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
-    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
-    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
     """
-
-    row_start = row_block_id * rows_per_program
-    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    row_idx = tl.program_id(0).to(tl.int64)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols
 
-    ... (old lines 106-137: previous implementation, not shown in this view)
-    tl.store(
-
+    # Pre-load weights once (same optimization as forward pass)
+    w = tl.load(W_ptr + cols, mask=mask, other=0.0)
+    w_f32 = w.to(tl.float32)
+    n_cols_f32 = n_cols.to(tl.float32)
+
+    # Calculate pointers for this specific row
+    row_X_ptr = X_ptr + row_idx * stride_x
+    row_DX_ptr = DX_ptr + row_idx * stride_dx
+    row_DY_ptr = DY_ptr + row_idx * stride_dy
+    row_Mean_ptr = Mean_ptr + row_idx
+    row_RSTD_ptr = RSTD_ptr + row_idx
+
+    # Load data for this row
+    x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+    dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+    mean = tl.load(row_Mean_ptr)
+    rstd = tl.load(row_RSTD_ptr)
+
+    # Convert to fp32 for numerical stability
+    x_f32 = x.to(tl.float32)
+    dy_f32 = dy.to(tl.float32)
+    mean_f32 = mean.to(tl.float32)
+    rstd_f32 = rstd.to(tl.float32)
+
+    # Compute backward pass for this row
+    x_hat = (x_f32 - mean_f32) * rstd_f32
+    wdy = w_f32 * dy_f32
+    c1 = tl.sum(x_hat * wdy, axis=0) / n_cols_f32
+    c2 = tl.sum(wdy, axis=0) / n_cols_f32
+    dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+    # Store input gradient
+    tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
+
+    # Accumulate weight and bias gradients using atomic operations
+    dw = dy_f32 * x_hat
+    db = dy_f32
+    tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
+    tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
 
 
 def layer_norm_forward(X, W, B, eps):
+    """
+    Args:
+        X: Input tensor of shape (..., hidden_size)
+        W: Weight tensor of shape (hidden_size,)
+        B: Bias tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tuple of (output, input, mean, rstd, block_size, num_warps)
+    """
     shape = X.shape
     dim = shape[-1]
     X = X.view(-1, dim)
     n_rows, n_cols = X.shape
+
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    # Allocate output tensors
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
     RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+    # Validate input dimensions
     if X.shape[1] != W.shape[0]:
         raise ValueError(
             f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
@@ -159,7 +187,9 @@ def layer_norm_forward(X, W, B, eps):
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"
 
-
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
+    _layer_norm_forward_kernel[grid](
         Y,
         Y.stride(0),
         X,
@@ -176,35 +206,43 @@ def layer_norm_forward(X, W, B, eps):
         eps,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
-        **kernel_args,
+        **kernel_args,
     )
+
    return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
 
 
 def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+    """
+    Args:
+        dY: Gradient of output
+        X: Input tensor
+        W: Weight tensor
+        B: Bias tensor
+        Mean: Pre-computed mean
+        RSTD: Pre-computed reciprocal standard deviation
+
+    Returns:
+        Tuple of (input_grad, weight_grad, bias_grad)
+    """
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape
 
-
-    if X.device.type == "cuda":
-        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
-    elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
-
+    # Allocate gradient tensors
     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-
-
+    # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
+    grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
+    DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+    DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
 
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
-        raise RuntimeError(
-            f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
-        )
+        raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
 
-
-    grid = (sm_count,)
+    # Determine dtype for triton operations
     triton_dtype = (
         tl.float32
         if X.dtype == torch.float32
@@ -212,41 +250,40 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         if X.dtype == torch.bfloat16
         else tl.float16
         if X.dtype == torch.float16
-        else tl.float32  # fallback
+        else tl.float32  # fallback
     )
 
+    # Use float32 for atomic operations if bfloat16 is not supported
+    atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+
+    kernel_args = {"num_warps": num_warps}
     # XPU-specific optimization
-    kernel_args = {}
     if X.device.type == "xpu":
         kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
 
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
     _layer_norm_backward_kernel[grid](
         X,
         W,
         Mean,
         RSTD,
         DX,
-
-
+        DW,
+        DB,
         dY,
         X.stride(0),
         DX.stride(0),
-        _DW.stride(0),
-        _DB.stride(0),
         dY.stride(0),
-        n_rows,
         n_cols,
-        rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
         dtype=triton_dtype,
-
+        atomic_dtype=atomic_dtype,
+        **kernel_args,
     )
 
-    DW = _DW.sum(dim=0).to(W.dtype)
-    DB = _DB.sum(dim=0).to(W.dtype)
-
     DX = DX.view(*shape)
-    return DX, DW, DB
+    return DX, DW.to(W.dtype), DB.to(W.dtype)
 
 
 class LigerLayerNormFunction(torch.autograd.Function):
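The rewritten backward kernel assigns one program per row and accumulates the weight and bias gradients with `tl.atomic_add`, which is why the Python wrapper now allocates `DW`/`DB` with `torch.zeros` (in fp32 when the weights are bfloat16, since bfloat16 atomics are not supported) instead of summing per-SM partial buffers as before. The per-row math is the standard LayerNorm backward. A minimal fp32 reference of that math, useful for sanity-checking the kernel; the function name and the 2-D `(n_rows, n_cols)` layout with per-row `mean`/`rstd` are assumptions for illustration, not part of the library API:

    import torch

    def layer_norm_backward_reference(dY, X, W, mean, rstd):
        # Mirrors _layer_norm_backward_kernel: x_hat, wdy, c1, c2, dx per row.
        x_hat = (X.float() - mean[:, None].float()) * rstd[:, None].float()
        wdy = W.float() * dY.float()
        n = X.shape[-1]
        c1 = (x_hat * wdy).sum(dim=-1, keepdim=True) / n
        c2 = wdy.sum(dim=-1, keepdim=True) / n
        dX = (wdy - (x_hat * c1 + c2)) * rstd[:, None].float()
        # Weight/bias gradients are reduced over rows, like the atomic adds in the kernel.
        dW = (dY.float() * x_hat).sum(dim=0)
        dB = dY.float().sum(dim=0)
        return dX.to(X.dtype), dW.to(W.dtype), dB.to(W.dtype)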
liger_kernel/ops/multi_token_attention.py
ADDED

@@ -0,0 +1,207 @@
+import torch
+import torch.nn.functional as F
+import triton
+import triton.language as tl
+
+from torch.nn.modules.utils import _pair
+
+from liger_kernel.ops.softmax import _softmax_forward
+from liger_kernel.ops.sparsemax import _sparsemax_backward
+from liger_kernel.ops.sparsemax import _sparsemax_forward
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+@triton.jit
+def _mask_fwd_kernel(
+    scores_ptr,
+    out_ptr,
+    stride_b,
+    stride_m,
+    stride_n,
+    L,
+    mask_val: tl.constexpr,
+    BLOCK: tl.constexpr,
+    num_warps: tl.constexpr,
+):
+    row_block = tl.program_id(0)
+    col_block = tl.program_id(1)
+    batch_id = tl.program_id(2)
+
+    row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
+    col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
+    in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
+
+    base = scores_ptr + batch_id * stride_b
+    offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
+    future = col_idx[None, :] > row_idx[:, None]
+    mask_load = in_bounds & ~future
+    out = tl.load(base + offs, mask=mask_load, other=mask_val, cache_modifier=".ca")
+    tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".cs")
+
+
+@triton.jit
+def _mask_bwd_kernel(
+    grad_in_ptr, out_ptr, stride_b, stride_m, stride_n, L, BLOCK: tl.constexpr, num_warps: tl.constexpr
+):
+    row_block = tl.program_id(0)
+    col_block = tl.program_id(1)
+    batch_id = tl.program_id(2)
+
+    row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
+    col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
+    in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
+
+    base = grad_in_ptr + batch_id * stride_b
+    offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
+    grad_vals = tl.load(base + offs, mask=in_bounds, other=0.0, cache_modifier=".ca")
+
+    future = col_idx[None, :] > row_idx[:, None]
+    zero = tl.zeros(grad_vals.shape, dtype=grad_vals.dtype)
+    out = tl.where(future, zero, grad_vals)
+
+    tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".wb")
+
+
+def _mask_inf_forward(scores: torch.Tensor) -> torch.Tensor:
+    *batch, L, _ = scores.shape
+    N = int(torch.prod(torch.tensor(batch))) if batch else 1
+    scores_f = scores.view(N, L, L)
+    out = torch.empty_like(scores_f)
+
+    sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
+    BLOCK_SIZE, num_warps = calculate_settings(L)
+    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+    _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=-1e9, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+    return out.view(*batch, L, L)
+
+
+def _mask_inf_backward(grad: torch.Tensor) -> torch.Tensor:
+    *batch, L, _ = grad.shape
+    N = int(torch.prod(torch.tensor(batch))) if batch else 1
+    grad_f = grad.view(N, L, L)
+    out = torch.empty_like(grad_f)
+
+    sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
+    BLOCK_SIZE, num_warps = calculate_settings(L)
+    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+    _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+    return out.view(*batch, L, L)
+
+
+def _mask_zero_forward(scores: torch.Tensor) -> torch.Tensor:
+    *batch, L, _ = scores.shape
+    N = int(torch.prod(torch.tensor(batch))) if batch else 1
+    scores_f = scores.view(N, L, L)
+    out = torch.empty_like(scores_f)
+
+    sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
+    BLOCK_SIZE, num_warps = calculate_settings(L)
+    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+    _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=0.0, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+    return out.view(*batch, L, L)
+
+
+def _mask_zero_backward(grad: torch.Tensor) -> torch.Tensor:
+    *batch, L, _ = grad.shape
+    N = int(torch.prod(torch.tensor(batch))) if batch else 1
+    grad_f = grad.view(N, L, L)
+    out = torch.empty_like(grad_f)
+
+    sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
+    BLOCK_SIZE, num_warps = calculate_settings(L)
+    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+    _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+    return out.view(*batch, L, L)
+
+
+class LigerMultiTokenAttentionFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, scores, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, sparse=False):
+        scores_inf = _mask_inf_forward(scores)
+
+        out_flat_sparse = None
+        activation_output = None
+
+        ctx.sparse = sparse
+
+        if sparse:
+            if scores_inf.dtype != torch.float32:
+                raise RuntimeError("Liger sparse multi-token attention currently only supports fp32 input scores")
+            probs_sparse, out_flat_sparse = _sparsemax_forward(scores_inf, dim=-1)
+            activation_output = probs_sparse
+            ctx.save_for_backward(scores_inf, activation_output, out_flat_sparse, weight, bias)
+            ctx.out_flat_sparse_saved = True
+        else:
+            probs_softmax, _, _, _ = _softmax_forward(scores_inf)
+            activation_output = probs_softmax
+            ctx.save_for_backward(scores_inf, activation_output, weight, bias)
+            ctx.out_flat_sparse_saved = False
+
+        out_conv = F.conv2d(
+            activation_output,
+            weight,
+            bias,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
+        out = _mask_zero_forward(out_conv)
+
+        ctx.stride = _pair(stride)
+        ctx.padding = _pair(padding)
+        ctx.dilation = _pair(dilation)
+        ctx.groups = groups
+        ctx.dim = -1
+
+        return out
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_out):
+        if ctx.out_flat_sparse_saved:
+            scores_inf, activation_output, out_flat_sparse, weight, bias = ctx.saved_tensors
+        else:
+            scores_inf, activation_output, weight, bias = ctx.saved_tensors
+            out_flat_sparse = None
+
+        use_sparsemax = ctx.sparse
+        dim = ctx.dim
+        stride, padding, dilation, groups = (ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
+
+        grad_conv = _mask_zero_backward(grad_out)
+
+        grad_probs = F.conv_transpose2d(
+            grad_conv, weight, None, stride=stride, padding=padding, dilation=dilation, groups=groups
+        )
+
+        grad_weight = torch.nn.grad.conv2d_weight(
+            input=activation_output,
+            weight_size=weight.shape,
+            grad_output=grad_conv,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+        grad_bias = None
+        if bias is not None:
+            grad_bias = grad_conv.sum(dim=(0, 2, 3))
+
+        grad_scores_inf = None
+        if use_sparsemax:
+            if not ctx.out_flat_sparse_saved or out_flat_sparse is None:
+                raise RuntimeError("Internal error: Sparse flag is set but sparse tensor was not saved.")
+            grad_scores_inf = _sparsemax_backward(grad_probs, out_flat_sparse, dim=dim)
+        else:
+            grad_probs_cont = grad_probs
+            probs_cont = activation_output
+            dot = (grad_probs_cont * probs_cont).sum(dim=-1, keepdim=True)
+            grad_scores_inf = probs_cont * (grad_probs_cont - dot)
+
+        grad_scores = _mask_inf_backward(grad_scores_inf)
+
+        return (grad_scores, grad_weight, grad_bias, None, None, None, None, None)