liger-kernel-nightly 0.5.10.dev20250524022630__py3-none-any.whl → 0.5.10.dev20250526154149__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +62 -50
- liger_kernel/transformers/functional.py +34 -0
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/softmax.py +12 -0
- {liger_kernel_nightly-0.5.10.dev20250524022630.dist-info → liger_kernel_nightly-0.5.10.dev20250526154149.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250524022630.dist-info → liger_kernel_nightly-0.5.10.dev20250526154149.dist-info}/RECORD +12 -8
- {liger_kernel_nightly-0.5.10.dev20250524022630.dist-info → liger_kernel_nightly-0.5.10.dev20250526154149.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250524022630.dist-info → liger_kernel_nightly-0.5.10.dev20250526154149.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250524022630.dist-info → liger_kernel_nightly-0.5.10.dev20250526154149.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250524022630.dist-info → liger_kernel_nightly-0.5.10.dev20250526154149.dist-info}/top_level.txt +0 -0
liger_kernel/ops/multi_token_attention.py ADDED
@@ -0,0 +1,207 @@
+import torch
+import torch.nn.functional as F
+import triton
+import triton.language as tl
+
+from torch.nn.modules.utils import _pair
+
+from liger_kernel.ops.softmax import _softmax_forward
+from liger_kernel.ops.sparsemax import _sparsemax_backward
+from liger_kernel.ops.sparsemax import _sparsemax_forward
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+@triton.jit
+def _mask_fwd_kernel(
+    scores_ptr,
+    out_ptr,
+    stride_b,
+    stride_m,
+    stride_n,
+    L,
+    mask_val: tl.constexpr,
+    BLOCK: tl.constexpr,
+    num_warps: tl.constexpr,
+):
+    row_block = tl.program_id(0)
+    col_block = tl.program_id(1)
+    batch_id = tl.program_id(2)
+
+    row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
+    col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
+    in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
+
+    base = scores_ptr + batch_id * stride_b
+    offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
+    future = col_idx[None, :] > row_idx[:, None]
+    mask_load = in_bounds & ~future
+    out = tl.load(base + offs, mask=mask_load, other=mask_val, cache_modifier=".ca")
+    tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".cs")
+
+
+@triton.jit
+def _mask_bwd_kernel(
+    grad_in_ptr, out_ptr, stride_b, stride_m, stride_n, L, BLOCK: tl.constexpr, num_warps: tl.constexpr
+):
+    row_block = tl.program_id(0)
+    col_block = tl.program_id(1)
+    batch_id = tl.program_id(2)
+
+    row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
+    col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
+    in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
+
+    base = grad_in_ptr + batch_id * stride_b
+    offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
+    grad_vals = tl.load(base + offs, mask=in_bounds, other=0.0, cache_modifier=".ca")
+
+    future = col_idx[None, :] > row_idx[:, None]
+    zero = tl.zeros(grad_vals.shape, dtype=grad_vals.dtype)
+    out = tl.where(future, zero, grad_vals)
+
+    tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".wb")
+
+
+def _mask_inf_forward(scores: torch.Tensor) -> torch.Tensor:
+    *batch, L, _ = scores.shape
+    N = int(torch.prod(torch.tensor(batch))) if batch else 1
+    scores_f = scores.view(N, L, L)
+    out = torch.empty_like(scores_f)
+
+    sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
+    BLOCK_SIZE, num_warps = calculate_settings(L)
+    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+    _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=-1e9, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+    return out.view(*batch, L, L)
+
+
+def _mask_inf_backward(grad: torch.Tensor) -> torch.Tensor:
+    *batch, L, _ = grad.shape
+    N = int(torch.prod(torch.tensor(batch))) if batch else 1
+    grad_f = grad.view(N, L, L)
+    out = torch.empty_like(grad_f)
+
+    sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
+    BLOCK_SIZE, num_warps = calculate_settings(L)
+    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+    _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+    return out.view(*batch, L, L)
+
+
+def _mask_zero_forward(scores: torch.Tensor) -> torch.Tensor:
+    *batch, L, _ = scores.shape
+    N = int(torch.prod(torch.tensor(batch))) if batch else 1
+    scores_f = scores.view(N, L, L)
+    out = torch.empty_like(scores_f)
+
+    sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
+    BLOCK_SIZE, num_warps = calculate_settings(L)
+    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+    _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=0.0, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+    return out.view(*batch, L, L)
+
+
+def _mask_zero_backward(grad: torch.Tensor) -> torch.Tensor:
+    *batch, L, _ = grad.shape
+    N = int(torch.prod(torch.tensor(batch))) if batch else 1
+    grad_f = grad.view(N, L, L)
+    out = torch.empty_like(grad_f)
+
+    sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
+    BLOCK_SIZE, num_warps = calculate_settings(L)
+    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+    _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+    return out.view(*batch, L, L)
+
+
+class LigerMultiTokenAttentionFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, scores, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, sparse=False):
+        scores_inf = _mask_inf_forward(scores)
+
+        out_flat_sparse = None
+        activation_output = None
+
+        ctx.sparse = sparse
+
+        if sparse:
+            if scores_inf.dtype != torch.float32:
+                raise RuntimeError("Liger sparse multi-token attention currently only supports fp32 input scores")
+            probs_sparse, out_flat_sparse = _sparsemax_forward(scores_inf, dim=-1)
+            activation_output = probs_sparse
+            ctx.save_for_backward(scores_inf, activation_output, out_flat_sparse, weight, bias)
+            ctx.out_flat_sparse_saved = True
+        else:
+            probs_softmax, _, _, _ = _softmax_forward(scores_inf)
+            activation_output = probs_softmax
+            ctx.save_for_backward(scores_inf, activation_output, weight, bias)
+            ctx.out_flat_sparse_saved = False
+
+        out_conv = F.conv2d(
+            activation_output,
+            weight,
+            bias,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
+        out = _mask_zero_forward(out_conv)
+
+        ctx.stride = _pair(stride)
+        ctx.padding = _pair(padding)
+        ctx.dilation = _pair(dilation)
+        ctx.groups = groups
+        ctx.dim = -1
+
+        return out
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_out):
+        if ctx.out_flat_sparse_saved:
+            scores_inf, activation_output, out_flat_sparse, weight, bias = ctx.saved_tensors
+        else:
+            scores_inf, activation_output, weight, bias = ctx.saved_tensors
+            out_flat_sparse = None
+
+        use_sparsemax = ctx.sparse
+        dim = ctx.dim
+        stride, padding, dilation, groups = (ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
+
+        grad_conv = _mask_zero_backward(grad_out)
+
+        grad_probs = F.conv_transpose2d(
+            grad_conv, weight, None, stride=stride, padding=padding, dilation=dilation, groups=groups
+        )
+
+        grad_weight = torch.nn.grad.conv2d_weight(
+            input=activation_output,
+            weight_size=weight.shape,
+            grad_output=grad_conv,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+        grad_bias = None
+        if bias is not None:
+            grad_bias = grad_conv.sum(dim=(0, 2, 3))
+
+        grad_scores_inf = None
+        if use_sparsemax:
+            if not ctx.out_flat_sparse_saved or out_flat_sparse is None:
+                raise RuntimeError("Internal error: Sparse flag is set but sparse tensor was not saved.")
+            grad_scores_inf = _sparsemax_backward(grad_probs, out_flat_sparse, dim=dim)
+        else:
+            grad_probs_cont = grad_probs
+            probs_cont = activation_output
+            dot = (grad_probs_cont * probs_cont).sum(dim=-1, keepdim=True)
+            grad_scores_inf = probs_cont * (grad_probs_cont - dot)
+
+        grad_scores = _mask_inf_backward(grad_scores_inf)
+
+        return (grad_scores, grad_weight, grad_bias, None, None, None, None, None)
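The new fused op above implements the pipeline mask(-inf) → softmax/sparsemax → conv2d → mask(0) with a hand-written backward. The sketch below is not part of the wheel; it is a hypothetical sanity check comparing the fused autograd function against the same composition written in eager PyTorch. The helper name eager_multi_token_attention, the shapes, the tolerances, and the CUDA device are illustrative assumptions, and running it requires a GPU with Triton plus this nightly wheel installed.

# Hypothetical check of the fused op against an eager reference (illustrative only).
import torch
import torch.nn.functional as F

from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction


def eager_multi_token_attention(scores, weight, bias, padding):
    # Same steps as the fused op: mask future positions with a large negative
    # value, softmax over the last dim, conv2d, then zero the future positions.
    L = scores.size(-1)
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=scores.device))
    probs = torch.softmax(scores.masked_fill(~causal, -1e9), dim=-1)
    out = F.conv2d(probs, weight, bias, padding=padding)
    return out.masked_fill(~causal, 0.0)


B, C, L, K = 2, 4, 64, 3
scores = torch.randn(B, C, L, L, device="cuda")
weight = torch.randn(C, C, K, K, device="cuda")
bias = torch.randn(C, device="cuda")

# forward(ctx, scores, weight, bias, stride, padding, dilation, groups, sparse)
out_fused = LigerMultiTokenAttentionFunction.apply(scores, weight, bias, 1, K // 2, 1, 1, False)
out_eager = eager_multi_token_attention(scores, weight, bias, K // 2)
print(torch.allclose(out_fused, out_eager, atol=1e-4, rtol=1e-4))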
liger_kernel/ops/softmax.py ADDED
@@ -0,0 +1,201 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+@triton.jit
+def _softmax_single_block_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < n_cols
+
+    x = tl.load(X_ptr + row_id * X_row_stride + offs, mask=mask, other=-float("inf"), cache_modifier=".ca")
+    m = tl.max(x, axis=0)
+    e = tl.exp(x - m)
+    d = tl.sum(e, axis=0)
+    y = e / d
+    tl.store(Y_ptr + row_id * Y_row_stride + offs, y, mask=mask, cache_modifier=".cs")
+
+
+@triton.jit
+def _softmax_multi_block_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+
+    m = tl.float32(-float("inf"))
+    d = tl.float32(0.0)
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+        blk_max = tl.max(xblk, axis=0)
+        new_m = tl.max(m, blk_max)
+        d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
+        m = new_m
+
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+        yblk = tl.exp(xblk - m) / d
+        tl.store(Y_ptr + row_id * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
+
+
+@triton.jit
+def _softmax_single_block_backward_kernel(
+    dy_ptr,
+    dy_stride,
+    y_ptr,
+    y_stride,
+    dx_ptr,
+    dx_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < n_cols
+
+    dy = tl.load(dy_ptr + row_id * dy_stride + offs, mask=mask, other=0.0)
+    y = tl.load(y_ptr + row_id * y_stride + offs, mask=mask, other=0.0, cache_modifier=".ca")
+    dot = tl.sum(dy * y, axis=0)
+    dx = y * (dy - dot)
+    tl.store(dx_ptr + row_id * dx_stride + offs, dx, mask=mask, cache_modifier=".wb")
+
+
+@triton.jit
+def _softmax_multi_block_backward_kernel(
+    dy_ptr,
+    dy_stride,
+    y_ptr,
+    y_stride,
+    dx_ptr,
+    dx_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+    acc = tl.float32(0.0)
+
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+        acc += tl.sum(dy_blk * y_blk, axis=0)
+
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+        dx_blk = y_blk * (dy_blk - acc)
+        tl.store(dx_ptr + row_id * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
+
+
+def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
+    *batch, n_cols = x.shape
+    x2d = x.contiguous().view(-1, n_cols)
+    n_rows = x2d.shape[0]
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    y2d = torch.empty_like(x2d)
+
+    if n_cols <= BLOCK_SIZE:
+        _softmax_single_block_forward_kernel[(n_rows,)](
+            y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+        )
+        multi_block_launch = False
+    else:
+        _softmax_multi_block_forward_kernel[(n_rows,)](
+            y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+        )
+        multi_block_launch = True
+
+    return y2d.view(*batch, n_cols), BLOCK_SIZE, num_warps, multi_block_launch
+
+
+def _softmax_backward(
+    dy: torch.Tensor,
+    y: torch.Tensor,
+    BLOCK_SIZE: int,
+    num_warps: int,
+    multi_block_launch: bool,
+) -> torch.Tensor:
+    *batch, n_cols = dy.shape
+    dy2d = dy.contiguous().view(-1, n_cols)
+    y2d = y.contiguous().view(-1, n_cols)
+    n_rows = dy2d.shape[0]
+    dx2d = torch.empty_like(dy2d)
+
+    if not multi_block_launch and n_cols <= BLOCK_SIZE:
+        _softmax_single_block_backward_kernel[(n_rows,)](
+            dy2d,
+            dy2d.stride(0),
+            y2d,
+            y2d.stride(0),
+            dx2d,
+            dx2d.stride(0),
+            n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+    else:
+        _softmax_multi_block_backward_kernel[(n_rows,)](
+            dy2d,
+            dy2d.stride(0),
+            y2d,
+            y2d.stride(0),
+            dx2d,
+            dx2d.stride(0),
+            n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+
+    return dx2d.view(*batch, n_cols)
+
+
+class LigerSoftmaxFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, input_: torch.Tensor):
+        y, BLOCK_SIZE, num_warps, multi_block_launch = _softmax_forward(input_)
+        ctx.save_for_backward(y)
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+        ctx.multi_block_launch = multi_block_launch
+        return y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        (y,) = ctx.saved_tensors
+        dx = _softmax_backward(
+            grad_output,
+            y,
+            ctx.BLOCK_SIZE,
+            ctx.num_warps,
+            ctx.multi_block_launch,
+        )
+        return dx
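The forward pass above picks a single-block kernel when an entire row fits into one BLOCK_SIZE tile and otherwise falls back to a two-pass multi-block kernel that keeps a running maximum m and running denominator d (online softmax). The snippet below is illustrative only and not from the package: it replays that recurrence for a single row in plain PyTorch so the two-pass structure is easy to follow.

# Plain-PyTorch illustration of the running-max/denominator recurrence used by
# _softmax_multi_block_forward_kernel (single row, CPU, block size is arbitrary).
import torch

def online_softmax_row(x: torch.Tensor, block: int) -> torch.Tensor:
    m = torch.tensor(float("-inf"))   # running max over blocks seen so far
    d = torch.tensor(0.0)             # running sum of exp(x - m)
    for start in range(0, x.numel(), block):
        blk = x[start:start + block]
        new_m = torch.maximum(m, blk.max())
        d = d * torch.exp(m - new_m) + torch.exp(blk - new_m).sum()
        m = new_m
    # second pass: normalize with the final statistics
    return torch.exp(x - m) / d

x = torch.randn(1000)
print(torch.allclose(online_softmax_row(x, block=128), torch.softmax(x, dim=0), atol=1e-6))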
liger_kernel/ops/sparsemax.py CHANGED
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import torch
 import triton
 import triton.language as tl
@@ -105,63 +107,73 @@ def _sparsemax_backward_kernel(
     tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")
 
 
+def _sparsemax_forward(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    if dim < 0:
+        dim += x.dim()
+    x_sw = x.transpose(dim, -1).contiguous()
+    n_cols = x_sw.size(-1)
+    n_rows = x_sw.numel() // n_cols
+    x_flat = x_sw.view(n_rows, n_cols)
+    x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    out_flat = torch.empty_like(x_flat)
+    grid = (n_rows,)
+    _sparsemax_forward_kernel[grid](
+        x_flat,
+        x_flat.stride(0),
+        x_sorted_flat,
+        x_sorted_flat.stride(0),
+        out_flat,
+        out_flat.stride(0),
+        n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+
+    y = out_flat.view_as(x_sw).transpose(dim, -1)
+    return y, out_flat
+
+
+def _sparsemax_backward(
+    grad_out: torch.Tensor,
+    out_flat: torch.Tensor,
+    dim: int,
+) -> torch.Tensor:
+    grad_sw = grad_out.transpose(dim, -1).contiguous()
+    n_cols = grad_sw.size(-1)
+    n_rows = grad_sw.numel() // n_cols
+    go_flat = grad_sw.view(n_rows, n_cols)
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    dx_flat = torch.empty_like(go_flat)
+    grid = (n_rows,)
+    _sparsemax_backward_kernel[grid](
+        out_flat,
+        go_flat,
+        dx_flat,
+        out_flat.stride(0),
+        n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+
+    dx = dx_flat.view_as(grad_sw).transpose(dim, -1)
+    return dx
+
+
 class LigerSparsemaxFunction(torch.autograd.Function):
     @staticmethod
     @ensure_contiguous
     def forward(ctx, x: torch.Tensor, dim: int):
-        if dim < 0:
-            dim += x.dim()
-        ctx.dim = dim
-
-        x_sw = x.transpose(dim, -1).contiguous()
-        n_cols = x_sw.size(-1)
-        n_rows = x_sw.numel() // n_cols
-        x_flat = x_sw.view(n_rows, n_cols)
-
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-        out_flat = torch.empty_like(x_flat)
-        grid = (n_rows,)
-
-        x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
-
-        _sparsemax_forward_kernel[grid](
-            x_flat,
-            x_flat.stride(0),
-            x_sorted_flat,
-            x_sorted_flat.stride(0),
-            out_flat,
-            out_flat.stride(0),
-            n_cols,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
-        )
-
+        y, out_flat = _sparsemax_forward(x, dim)
         ctx.save_for_backward(out_flat)
-
+        ctx.dim = dim
+        return y
 
     @staticmethod
     @ensure_contiguous
     def backward(ctx, grad_out: torch.Tensor):
         (out_flat,) = ctx.saved_tensors
-        dim = ctx.dim
-
-        go_sw = grad_out.transpose(dim, -1).contiguous()
-        n_cols = go_sw.size(-1)
-        n_rows = go_sw.numel() // n_cols
-        go_flat = go_sw.view(n_rows, n_cols)
-
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-        gi_flat = torch.empty_like(go_flat)
-        grid = (n_rows,)
-
-        _sparsemax_backward_kernel[grid](
-            out_flat,
-            go_flat,
-            gi_flat,
-            out_flat.stride(0),
-            n_cols,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
-        )
-
-        return gi_flat.view_as(go_sw).transpose(dim, -1), None
+        dx = _sparsemax_backward(grad_out, out_flat, ctx.dim)
+        return dx, None
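This change is a behavior-preserving refactor: the Triton launch logic that previously lived inline in LigerSparsemaxFunction.forward/backward is now wrapped by _sparsemax_forward and _sparsemax_backward, which also return the flattened sparsemax output needed by the backward kernel so that the new multi_token_attention op can reuse them. Below is a minimal sketch of calling the helpers directly; it assumes a CUDA device, fp32 inputs, this nightly wheel, and illustrative shapes.

# Sketch only: direct use of the new sparsemax helpers (shapes illustrative).
import torch

from liger_kernel.ops.sparsemax import _sparsemax_backward
from liger_kernel.ops.sparsemax import _sparsemax_forward

scores = torch.randn(2, 4, 64, 64, device="cuda", dtype=torch.float32)

# Forward returns the sparsemax probabilities plus the flattened (n_rows, n_cols)
# output that the backward kernel consumes.
probs, out_flat = _sparsemax_forward(scores, dim=-1)

grad_probs = torch.randn_like(probs)
grad_scores = _sparsemax_backward(grad_probs, out_flat, dim=-1)
print(probs.shape, grad_scores.shape)  # both (2, 4, 64, 64)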
liger_kernel/transformers/functional.py CHANGED
@@ -9,9 +9,11 @@ from liger_kernel.ops.group_norm import LigerGroupNormFunction
 from liger_kernel.ops.jsd import LigerJSDFunction
 from liger_kernel.ops.kl_div import LigerKLDivLossFunction
 from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
 from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
 from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 from liger_kernel.ops.rope import LigerRopeFunction
+from liger_kernel.ops.softmax import LigerSoftmaxFunction
 from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction
 from liger_kernel.ops.tvd import LigerTVDLossFunction
@@ -167,6 +169,34 @@ def liger_sparsemax(
     return LigerSparsemaxFunction.apply(input, dim)
 
 
+def liger_multi_token_attention(
+    scores,
+    weight,
+    bias=None,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+    groups: int = 1,
+    sparse: bool = False,
+):
+    """
+    Functional interface for multi-token attention.
+
+    Args:
+        scores: Input tensor of shape (B, C_in, L, L)
+        weight: Convolution weight tensor of shape (C_out, C_in // groups, K, K)
+        bias: Optional bias tensor of shape (C_out,)
+        stride: Stride for the convolution (default: 1)
+        padding: Padding for the convolution (default: 0)
+        dilation: Dilation factor for the convolution (default: 1)
+        groups: Number of groups for the convolution (default: 1)
+        sparse: Specifies if input tensors are expected to be sparse (default: False)
+    Returns:
+        Output tensor after applying multi-token attention.
+    """
+    return LigerMultiTokenAttentionFunction.apply(scores, weight, bias, stride, padding, dilation, groups, sparse)
+
+
 def liger_tvd(
     input,
     target,
@@ -203,5 +233,9 @@ def liger_swiglu(a, b):
     return LigerSiLUMulFunction.apply(a, b)
 
 
+def liger_softmax(x):
+    return LigerSoftmaxFunction.apply(x)
+
+
 def liger_dyt(x, alpha, gamma, beta):
     return LigerDyTFunction.apply(x, alpha, gamma, beta)
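functional.py gains two thin wrappers, liger_multi_token_attention and liger_softmax, alongside the existing liger_* helpers. A hedged usage sketch follows; the shapes, the padding choice, and the CUDA device are illustrative assumptions, and it requires a GPU with Triton.

# Sketch only: calling the new functional wrappers.
import torch

from liger_kernel.transformers.functional import liger_multi_token_attention
from liger_kernel.transformers.functional import liger_softmax

B, C, L, K = 2, 4, 64, 3
scores = torch.randn(B, C, L, L, device="cuda")
weight = torch.randn(C, C, K, K, device="cuda")
bias = torch.randn(C, device="cuda")

# padding=K // 2 keeps the (L, L) attention map square so the final causal zero-mask applies.
out = liger_multi_token_attention(scores, weight, bias, padding=K // 2)

# Softmax over the last dimension via the new Triton kernel.
probs = liger_softmax(torch.randn(B, L, L, device="cuda"))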
liger_kernel/transformers/multi_token_attention.py ADDED
@@ -0,0 +1,64 @@
+import math
+
+import torch
+import torch.nn as nn
+
+from torch.nn.modules.utils import _pair
+
+from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
+
+
+class LigerMultiTokenAttention(nn.Module):
+    """
+    Multi-Token Attention:
+    out = mask_{0}(conv2d(softmax(mask_{-\inf}(scores))))
+
+    Reference: https://arxiv.org/pdf/2504.00927
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        sparse: bool = False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.sparse = sparse
+
+        self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, *self.kernel_size))
+        if bias:
+            self.bias = nn.Parameter(torch.empty(out_channels))
+        else:
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
+
+    def forward(self, scores: torch.Tensor) -> torch.Tensor:
+        return LigerMultiTokenAttentionFunction.apply(
+            scores,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.sparse,
+        )
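LigerMultiTokenAttention is the module-level wrapper: it owns the conv2d weight and bias (kaiming-uniform weight, zero bias) and forwards everything to LigerMultiTokenAttentionFunction. A minimal usage sketch, assuming a CUDA device and illustrative channel/sequence sizes:

# Sketch only: module-level usage.
import torch

from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention

# padding=1 with kernel_size=3 keeps the (L, L) attention map square.
mta = LigerMultiTokenAttention(in_channels=8, out_channels=8, kernel_size=3, padding=1).cuda()

scores = torch.randn(2, 8, 128, 128, device="cuda")  # (B, C_in, L, L) attention logits
out = mta(scores)                                     # (B, C_out, L, L), future positions zeroed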
liger_kernel/transformers/softmax.py ADDED
@@ -0,0 +1,12 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.softmax import LigerSoftmaxFunction
+
+
+class LigerKernelSoftmax(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor):
+        return LigerSoftmaxFunction.apply(x)
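LigerKernelSoftmax is a drop-in module for softmax over the last dimension, backed by LigerSoftmaxFunction. A short usage sketch (CUDA device and the nightly wheel assumed):

# Sketch only.
import torch

from liger_kernel.transformers.softmax import LigerKernelSoftmax

softmax = LigerKernelSoftmax()
x = torch.randn(4, 1024, device="cuda", requires_grad=True)
y = softmax(x)          # matches torch.softmax(x, dim=-1)
y.sum().backward()      # gradient computed by the Triton backward kernel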
{liger_kernel_nightly-0.5.10.dev20250524022630.dist-info → liger_kernel_nightly-0.5.10.dev20250526154149.dist-info}/RECORD CHANGED
@@ -26,10 +26,12 @@ liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0
 liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
 liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
 liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
+liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
 liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
 liger_kernel/ops/rms_norm.py,sha256=PP27OIBmV9By63i13jot9ylDowW0nuxY_JFIkaPLgL4,12078
 liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
-liger_kernel/ops/
+liger_kernel/ops/softmax.py,sha256=tgORx6MK1IDDtZKqGarj0IPIVjqAIEUXXYPiinhRdtI,5864
+liger_kernel/ops/sparsemax.py,sha256=AeWe1xgkHJFEKWTj2vu_0hj7LztGvjqXAps-QTpCY0U,5087
 liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
 liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
 liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
@@ -40,7 +42,7 @@ liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawX
 liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
 liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
 liger_kernel/transformers/fsdp.py,sha256=CUiyjTmjkjY7pLXQv8ly9rnzgXw6529csd9pvtJNMYc,3096
-liger_kernel/transformers/functional.py,sha256=
+liger_kernel/transformers/functional.py,sha256=QmnAFpRgIbp9Rzlfp8QibwiEbf5BUcANxfY68an7o8c,6444
 liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
 liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
 liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
@@ -51,9 +53,11 @@ liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCc
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
 liger_kernel/transformers/monkey_patch.py,sha256=DKv5-4KyXLiVhAJ9WVFv1I1i1DzjaudTrhqx6EVYViU,74505
+liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
 liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
 liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
+liger_kernel/transformers/softmax.py,sha256=u7bFo35-cjaAm9of6-DLzmkaNFELOM-9AgyrcvUPifw,270
 liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
 liger_kernel/transformers/swiglu.py,sha256=LZ8YeLIdv2k46JleZMjzubGk98smt6t780kSgcVLsQk,3454
 liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
@@ -82,9 +86,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
+liger_kernel_nightly-0.5.10.dev20250526154149.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.10.dev20250526154149.dist-info/METADATA,sha256=0CXMJx6ef3SurofjUlAWwMaj-prwFf-xg_nMo-n6UPE,24113
+liger_kernel_nightly-0.5.10.dev20250526154149.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.10.dev20250526154149.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.10.dev20250526154149.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.10.dev20250526154149.dist-info/RECORD,,