liger-kernel 0.5.10__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/ops/dyt.py +0 -2
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +265 -54
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +62 -50
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/functional.py +62 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/model/gemma.py +25 -8
- liger_kernel/transformers/model/gemma2.py +27 -8
- liger_kernel/transformers/model/gemma3.py +62 -98
- liger_kernel/transformers/model/glm4.py +16 -7
- liger_kernel/transformers/model/llama.py +25 -7
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -124
- liger_kernel/transformers/model/mistral.py +13 -8
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +16 -7
- liger_kernel/transformers/model/olmo2.py +16 -7
- liger_kernel/transformers/model/paligemma.py +8 -1
- liger_kernel/transformers/model/phi3.py +25 -8
- liger_kernel/transformers/model/qwen2.py +24 -7
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
- liger_kernel/transformers/model/qwen2_vl.py +38 -100
- liger_kernel/transformers/model/qwen3.py +11 -3
- liger_kernel/transformers/model/qwen3_moe.py +10 -6
- liger_kernel/transformers/monkey_patch.py +304 -70
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +8 -2
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/RECORD +42 -35
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
import triton.language as tl
|
|
6
|
+
|
|
7
|
+
from liger_kernel.ops.utils import calculate_settings
|
|
8
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@triton.jit
def _softmax_single_block_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Softmax over one row whose length fits in a single block (BLOCK_SIZE >= n_cols).

    Each program handles the row selected by ``tl.program_id(0)``. The arithmetic is
    done in float32 regardless of the input dtype; the result is cast back to the
    element type of ``Y_ptr`` when stored.
    """
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols

    # Padding lanes read -inf: they never win the max and contribute exp(-inf) = 0 to the sum.
    x = tl.load(
        X_ptr + row_id * X_row_stride + offs,
        mask=mask,
        other=-float("inf"),
        cache_modifier=".ca",
    ).to(tl.float32)

    # Subtract the row max before exponentiating for numerical stability.
    m = tl.max(x, axis=0)
    e = tl.exp(x - m)
    d = tl.sum(e, axis=0)
    y = e / d

    tl.store(
        Y_ptr + row_id * Y_row_stride + offs,
        y.to(Y_ptr.dtype.element_ty),
        mask=mask,
        cache_modifier=".cs",
    )
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@triton.jit
def _softmax_multi_block_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Softmax over one row that is processed in several BLOCK_SIZE chunks.

    Pass 1 streams over the row and keeps a running max ``m`` and a normalizer ``d``
    (rescaled whenever the max grows). Pass 2 re-reads the row and writes
    ``exp(x - m) / d``. All arithmetic is float32; the output is cast to the element
    type of ``Y_ptr`` on store.
    """
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)

    # Plain float literals give float32 loop-carried scalars.
    m = float("-inf")
    d = 0.0
    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        xblk = tl.load(
            X_ptr + row_id * X_row_stride + idx,
            mask=mask,
            other=-float("inf"),
            cache_modifier=".ca",
        ).to(tl.float32)
        blk_max = tl.max(xblk, axis=0)
        # Element-wise max of two scalars. tl.max is a reduction whose 2nd argument is `axis`.
        new_m = tl.maximum(m, blk_max)
        # While every value seen so far is -inf, shift by 0 instead of -inf so we never
        # compute (-inf) - (-inf) = NaN. The contributions are exp(-inf) = 0 either way.
        shift = tl.where(new_m == float("-inf"), 0.0, new_m)
        d = d * tl.exp(m - shift) + tl.sum(tl.exp(xblk - shift), axis=0)
        m = new_m

    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        xblk = tl.load(
            X_ptr + row_id * X_row_stride + idx,
            mask=mask,
            other=-float("inf"),
            cache_modifier=".ca",
        ).to(tl.float32)
        yblk = tl.exp(xblk - m) / d
        tl.store(
            Y_ptr + row_id * Y_row_stride + idx,
            yblk.to(Y_ptr.dtype.element_ty),
            mask=mask,
            cache_modifier=".cs",
        )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@triton.jit
def _softmax_single_block_backward_kernel(
    dy_ptr,
    dy_stride,
    y_ptr,
    y_stride,
    dx_ptr,
    dx_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Softmax backward for one row that fits in a single block (BLOCK_SIZE >= n_cols).

    Computes ``dx = y * (dy - sum(dy * y))`` in float32 and stores it cast to the
    element type of ``dx_ptr``.
    """
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols

    # Padding lanes read 0 so they drop out of the dot product.
    dy = tl.load(dy_ptr + row_id * dy_stride + offs, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + row_id * y_stride + offs, mask=mask, other=0.0, cache_modifier=".ca").to(tl.float32)
    dot = tl.sum(dy * y, axis=0)
    dx = y * (dy - dot)
    tl.store(
        dx_ptr + row_id * dx_stride + offs,
        dx.to(dx_ptr.dtype.element_ty),
        mask=mask,
        cache_modifier=".wb",
    )
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@triton.jit
def _softmax_multi_block_backward_kernel(
    dy_ptr,
    dy_stride,
    y_ptr,
    y_stride,
    dx_ptr,
    dx_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Softmax backward for one row processed in several BLOCK_SIZE chunks.

    Pass 1 accumulates ``sum(dy * y)`` over the row in float32. Pass 2 writes
    ``dx = y * (dy - sum)`` cast to the element type of ``dx_ptr``.
    """
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    # A plain float literal gives a float32 loop-carried accumulator.
    acc = 0.0

    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0).to(tl.float32)
        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca").to(tl.float32)
        acc += tl.sum(dy_blk * y_blk, axis=0)

    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0).to(tl.float32)
        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca").to(tl.float32)
        dx_blk = y_blk * (dy_blk - acc)
        tl.store(
            dx_ptr + row_id * dx_stride + idx,
            dx_blk.to(dx_ptr.dtype.element_ty),
            mask=mask,
            cache_modifier=".wb",
        )
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
    """
    Launch the Triton softmax forward over the last dimension of ``x``.

    All leading dimensions are flattened into rows and one program is launched per row.

    Returns:
        A tuple ``(y, BLOCK_SIZE, num_warps, multi_block_launch)``:
        - ``y`` has the same shape as ``x``.
        - ``BLOCK_SIZE`` and ``num_warps`` come from ``calculate_settings(n_cols)``.
        - ``multi_block_launch`` is True when ``n_cols > BLOCK_SIZE``.
        The backward pass reuses the last three values.
    """
    *batch, n_cols = x.shape
    rows = x.contiguous().view(-1, n_cols)

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    out = torch.empty_like(rows)

    multi_block_launch = n_cols > BLOCK_SIZE
    kernel = _softmax_multi_block_forward_kernel if multi_block_launch else _softmax_single_block_forward_kernel
    kernel[(rows.shape[0],)](
        out,
        out.stride(0),
        rows,
        rows.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return out.view(*batch, n_cols), BLOCK_SIZE, num_warps, multi_block_launch
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _softmax_backward(
    dy: torch.Tensor,
    y: torch.Tensor,
    BLOCK_SIZE: int,
    num_warps: int,
    multi_block_launch: bool,
) -> torch.Tensor:
    """
    Launch the Triton softmax backward over the last dimension.

    Args:
        dy: Upstream gradient, same shape as ``y``.
        y: Softmax output saved from the forward pass.
        BLOCK_SIZE, num_warps, multi_block_launch: Launch configuration returned by
            ``_softmax_forward``.

    Returns:
        The input gradient, with the same shape as ``dy``.
    """
    *batch, n_cols = dy.shape
    grad_rows = dy.contiguous().view(-1, n_cols)
    prob_rows = y.contiguous().view(-1, n_cols)
    grad_in = torch.empty_like(grad_rows)

    # The single-block kernel is only valid when the whole row fits in one block.
    use_single_block = (not multi_block_launch) and n_cols <= BLOCK_SIZE
    kernel = _softmax_single_block_backward_kernel if use_single_block else _softmax_multi_block_backward_kernel
    kernel[(grad_rows.shape[0],)](
        grad_rows,
        grad_rows.stride(0),
        prob_rows,
        prob_rows.stride(0),
        grad_in,
        grad_in.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return grad_in.view(*batch, n_cols)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class LigerSoftmaxFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton softmax kernels (softmax over the last dimension)."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, input_: torch.Tensor):
        probs, block_size, warps, multi_block = _softmax_forward(input_)
        ctx.save_for_backward(probs)
        # Remember the launch configuration so backward picks the matching kernel variant.
        ctx.BLOCK_SIZE = block_size
        ctx.num_warps = warps
        ctx.multi_block_launch = multi_block
        return probs

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        (probs,) = ctx.saved_tensors
        return _softmax_backward(
            grad_output,
            probs,
            ctx.BLOCK_SIZE,
            ctx.num_warps,
            ctx.multi_block_launch,
        )
|
liger_kernel/ops/sparsemax.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
1
3
|
import torch
|
|
2
4
|
import triton
|
|
3
5
|
import triton.language as tl
|
|
@@ -105,63 +107,73 @@ def _sparsemax_backward_kernel(
|
|
|
105
107
|
tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")
|
|
106
108
|
|
|
107
109
|
|
|
110
|
+
def _sparsemax_forward(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run ``_sparsemax_forward_kernel`` along dimension ``dim`` of ``x``.

    ``dim`` is moved to the last position and the tensor is flattened to
    ``(n_rows, n_cols)``. The kernel receives the raw rows together with a float32,
    descending-sorted copy of them.

    Returns:
        ``(y, out_flat)``:
        - ``y`` has the shape of ``x`` (a transposed view of ``out_flat``).
        - ``out_flat`` is the contiguous 2-D kernel output, kept for the backward pass.
    """
    dim = dim + x.dim() if dim < 0 else dim

    swapped = x.transpose(dim, -1).contiguous()
    width = swapped.size(-1)
    height = swapped.numel() // width
    rows = swapped.view(height, width)
    sorted_rows = torch.sort(rows.float(), dim=-1, descending=True).values

    BLOCK_SIZE, num_warps = calculate_settings(width)
    out_flat = torch.empty_like(rows)
    _sparsemax_forward_kernel[(height,)](
        rows,
        rows.stride(0),
        sorted_rows,
        sorted_rows.stride(0),
        out_flat,
        out_flat.stride(0),
        width,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return out_flat.view_as(swapped).transpose(dim, -1), out_flat
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _sparsemax_backward(
    grad_out: torch.Tensor,
    out_flat: torch.Tensor,
    dim: int,
) -> torch.Tensor:
    """
    Run ``_sparsemax_backward_kernel`` to get the input gradient of sparsemax.

    Args:
        grad_out: Upstream gradient (presumably shaped like the forward output).
        out_flat: The 2-D ``(n_rows, n_cols)`` forward output returned by ``_sparsemax_forward``.
        dim: Dimension sparsemax was applied over. Negative values are accepted by ``transpose``.

    Returns:
        A gradient tensor with the same shape as ``grad_out``.
    """
    # Move `dim` last and flatten to rows, mirroring the forward layout.
    swapped_grad = grad_out.transpose(dim, -1).contiguous()
    width = swapped_grad.size(-1)
    grad_rows = swapped_grad.view(swapped_grad.numel() // width, width)

    BLOCK_SIZE, num_warps = calculate_settings(width)
    grad_in_rows = torch.empty_like(grad_rows)
    # The kernel takes one row stride for all three buffers. out_flat, grad_rows and
    # grad_in_rows are assumed to be contiguous matrices of the same width.
    _sparsemax_backward_kernel[(grad_rows.shape[0],)](
        out_flat,
        grad_rows,
        grad_in_rows,
        out_flat.stride(0),
        width,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return grad_in_rows.view_as(swapped_grad).transpose(dim, -1)
|
|
163
|
+
|
|
164
|
+
|
|
108
165
|
class LigerSparsemaxFunction(torch.autograd.Function):
    """Autograd wrapper for sparsemax along an arbitrary dimension using the Triton kernels."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x: torch.Tensor, dim: int):
        result, flat_out = _sparsemax_forward(x, dim)
        # The flat 2-D output is what the backward kernel consumes.
        ctx.save_for_backward(flat_out)
        ctx.dim = dim
        return result

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out: torch.Tensor):
        (flat_out,) = ctx.saved_tensors
        # `dim` is a plain int argument, so it gets no gradient.
        return _sparsemax_backward(grad_out, flat_out, ctx.dim), None
|
liger_kernel/ops/swiglu.py
CHANGED
|
@@ -26,7 +26,7 @@ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BL
|
|
|
26
26
|
# sigmoid requires type float32
|
|
27
27
|
a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
|
28
28
|
b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
|
|
29
|
-
c_row = silu(a_row) * b_row
|
|
29
|
+
c_row = silu(a_row).cast(b_row.dtype) * b_row
|
|
30
30
|
tl.store(c_ptr + col_offsets, c_row, mask=mask)
|
|
31
31
|
|
|
32
32
|
|
|
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
|
|
|
30
30
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
|
|
31
31
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
|
|
32
32
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
|
|
33
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
|
|
33
34
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
|
|
34
35
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
|
|
35
36
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
|
|
@@ -87,6 +88,7 @@ def __getattr__(name: str):
|
|
|
87
88
|
"apply_liger_kernel_to_granite",
|
|
88
89
|
"apply_liger_kernel_to_llama",
|
|
89
90
|
"apply_liger_kernel_to_llava",
|
|
91
|
+
"apply_liger_kernel_to_llama4",
|
|
90
92
|
"apply_liger_kernel_to_mistral",
|
|
91
93
|
"apply_liger_kernel_to_mixtral",
|
|
92
94
|
"apply_liger_kernel_to_mllama",
|
|
@@ -141,6 +143,7 @@ if _TRANSFORMERS_AVAILABLE:
|
|
|
141
143
|
"apply_liger_kernel_to_granite",
|
|
142
144
|
"apply_liger_kernel_to_llama",
|
|
143
145
|
"apply_liger_kernel_to_llava",
|
|
146
|
+
"apply_liger_kernel_to_llama4",
|
|
144
147
|
"apply_liger_kernel_to_mistral",
|
|
145
148
|
"apply_liger_kernel_to_mixtral",
|
|
146
149
|
"apply_liger_kernel_to_mllama",
|
|
@@ -4,14 +4,17 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
|
|
|
4
4
|
from liger_kernel.ops.dyt import LigerDyTFunction
|
|
5
5
|
from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
|
|
6
6
|
from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
|
|
7
|
+
from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
|
|
7
8
|
from liger_kernel.ops.geglu import LigerGELUMulFunction
|
|
8
9
|
from liger_kernel.ops.group_norm import LigerGroupNormFunction
|
|
9
10
|
from liger_kernel.ops.jsd import LigerJSDFunction
|
|
10
11
|
from liger_kernel.ops.kl_div import LigerKLDivLossFunction
|
|
11
12
|
from liger_kernel.ops.layer_norm import LigerLayerNormFunction
|
|
13
|
+
from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
|
|
12
14
|
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
|
|
13
15
|
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
|
|
14
16
|
from liger_kernel.ops.rope import LigerRopeFunction
|
|
17
|
+
from liger_kernel.ops.softmax import LigerSoftmaxFunction
|
|
15
18
|
from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
|
|
16
19
|
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
|
|
17
20
|
from liger_kernel.ops.tvd import LigerTVDLossFunction
|
|
@@ -167,6 +170,61 @@ def liger_sparsemax(
|
|
|
167
170
|
return LigerSparsemaxFunction.apply(input, dim)
|
|
168
171
|
|
|
169
172
|
|
|
173
|
+
def liger_multi_token_attention(
    scores,
    weight,
    bias=None,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    sparse: bool = False,
):
    """
    Functional interface for multi-token attention.

    Thin wrapper that forwards every argument, in order, to
    ``LigerMultiTokenAttentionFunction.apply``.

    Args:
        scores: Input tensor of shape (B, C_in, L, L)
        weight: Convolution weight tensor of shape (C_out, C_in // groups, K, K)
        bias: Optional bias tensor of shape (C_out,)
        stride: Stride for the convolution (default: 1)
        padding: Padding for the convolution (default: 0)
        dilation: Dilation factor for the convolution (default: 1)
        groups: Number of groups for the convolution (default: 1)
        sparse: Passed through as the function's ``sparse`` flag (default: False)

    Returns:
        Output tensor after applying multi-token attention.
    """
    conv_options = (stride, padding, dilation, groups)
    return LigerMultiTokenAttentionFunction.apply(scores, weight, bias, *conv_options, sparse)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def liger_fused_neighborhood_attention(
    query,
    key,
    value,
    kernel_size: int = 7,
    dilation: int = 1,
    scale: float = None,
):
    """
    Liger fused neighborhood attention.

    paper: https://arxiv.org/pdf/2504.16922

    Args:
        query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
        key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
        value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
        kernel_size: Size of the neighborhood window (default: 7)
        dilation: Dilation factor for the neighborhood (default: 1)
        scale: Optional scaling factor for attention scores (None selects the
            default, rsqrt(head_dim))

    Returns:
        Output tensor of shape [batch_size, num_heads, seq_len, head_dim]
    """
    qkv = (query, key, value)
    return LigerFusedNeighborhoodAttentionFunction.apply(*qkv, kernel_size, dilation, scale)
|
|
226
|
+
|
|
227
|
+
|
|
170
228
|
def liger_tvd(
|
|
171
229
|
input,
|
|
172
230
|
target,
|
|
@@ -203,5 +261,9 @@ def liger_swiglu(a, b):
|
|
|
203
261
|
return LigerSiLUMulFunction.apply(a, b)
|
|
204
262
|
|
|
205
263
|
|
|
264
|
+
def liger_softmax(x):
    """Apply the Liger Triton softmax (``LigerSoftmaxFunction``) to ``x`` over its last dimension."""
    result = LigerSoftmaxFunction.apply(x)
    return result
|
|
266
|
+
|
|
267
|
+
|
|
206
268
|
def liger_dyt(x, alpha, gamma, beta):
    """
    Apply ``LigerDyTFunction`` to ``x`` with parameters ``alpha``, ``gamma`` and ``beta``.

    The exact semantics of the parameters are defined in ``liger_kernel.ops.dyt``.
    """
    output = LigerDyTFunction.apply(x, alpha, gamma, beta)
    return output
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
|
|
8
|
+
from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LigerFusedNeighborhoodAttention(nn.Module):
    """
    Liger Fused Neighborhood Attention Module.

    Paper: https://arxiv.org/pdf/2504.16922

    Fused Neighborhood attention restricts the attention mechanism to a local neighborhood
    around each position, reducing computational complexity from O(n²) to O(n*k)
    where k is the neighborhood size.

    The forward pass:
    1. Projects ``hidden_states`` into queries, keys and values with three ``nn.Linear`` layers.
    2. Splits them into ``num_heads`` heads.
    3. Runs ``LigerFusedNeighborhoodAttentionFunction``.
    4. Merges the heads, applies optional dropout, and finishes with ``out_proj``.

    Args:
        hidden_size (int): The hidden dimension size
        num_heads (int): Number of attention heads (must be positive)
        kernel_size (int): Size of the neighborhood window (default: 7)
        dilation (int): Dilation factor for the neighborhood (default: 1)
        bias (bool): Whether to use bias in linear projections (default: True)
        dropout (float): Dropout probability (default: 0.0)
        scale (Optional[float]): Scaling factor for attention scores.
            If None, uses 1/sqrt(head_dim) (default: None)

    Raises:
        ValueError: If ``num_heads`` is not positive, ``hidden_size`` is not divisible by
            ``num_heads``, ``kernel_size`` is not a positive odd integer, or ``dilation < 1``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        kernel_size: int = 7,
        dilation: int = 1,
        bias: bool = True,
        dropout: float = 0.0,
        scale: Optional[float] = None,
    ):
        super().__init__()

        # Checked before the divisibility test. Otherwise num_heads=0 raises
        # ZeroDivisionError from the modulo, and a negative value passes the modulo
        # and yields a negative head_dim.
        if num_heads <= 0:
            raise ValueError(f"num_heads ({num_heads}) must be positive")

        if hidden_size % num_heads != 0:
            raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")

        if kernel_size <= 0:
            raise ValueError(f"kernel_size ({kernel_size}) must be positive")

        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size ({kernel_size}) must be odd")

        if dilation < 1:
            raise ValueError(f"dilation ({dilation}) must be positive")

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.scale = scale if scale is not None else 1.0 / math.sqrt(self.head_dim)
        self.dropout_p = dropout

        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

        # No Dropout module is created for p == 0, so forward can skip it entirely.
        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the fused neighborhood attention module.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
            attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]

        Raises:
            NotImplementedError: If ``attention_mask`` is not None.
        """
        if attention_mask is not None:
            raise NotImplementedError("Attention mask is not yet supported in LigerFusedNeighborhoodAttention")

        batch_size, seq_len, hidden_size = hidden_states.shape

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # [B, L, hidden] -> [B, num_heads, L, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn_output = LigerFusedNeighborhoodAttentionFunction.apply(
            query, key, value, self.kernel_size, self.dilation, self.scale
        )

        # [B, num_heads, L, head_dim] -> [B, L, hidden]. contiguous() is needed before view() after transpose.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)

        if self.dropout is not None:
            attn_output = self.dropout(attn_output)

        output = self.out_proj(attn_output)

        return output

    def extra_repr(self) -> str:
        return (
            f"hidden_size={self.hidden_size}, num_heads={self.num_heads}, "
            f"head_dim={self.head_dim}, kernel_size={self.kernel_size}, "
            f"dilation={self.dilation}, scale={self.scale}, dropout={self.dropout_p}"
        )
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class LigerFusedNeighborhoodAttentionLayer(nn.Module):
    """
    A complete neighborhood attention layer with layer norm and residual connection.

    The layer is pre-norm: ``LayerNorm`` is applied to the input before attention, and
    the un-normalized input is added back as the residual.

    Args:
        hidden_size (int): The hidden dimension size
        num_heads (int): Number of attention heads
        kernel_size (int): Size of the neighborhood window (default: 7)
        dilation (int): Dilation factor for the neighborhood (default: 1)
        bias (bool): Whether to use bias in linear projections (default: True)
        dropout (float): Dropout probability (default: 0.0)
        layer_norm_eps (float): Epsilon for layer normalization (default: 1e-5)
        scale (Optional[float]): Scaling factor for attention scores (default: None)
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        kernel_size: int = 7,
        dilation: int = 1,
        bias: bool = True,
        dropout: float = 0.0,
        layer_norm_eps: float = 1e-5,
        scale: Optional[float] = None,
    ):
        super().__init__()

        self.attention = LigerFusedNeighborhoodAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=bias,
            dropout=dropout,
            scale=scale,
        )
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # Residual-branch dropout, applied after the attention module's own output projection.
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with residual connection and layer normalization.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
            attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
        """
        residual = hidden_states
        branch = self.attention(self.layer_norm(hidden_states), attention_mask)
        if self.dropout is not None:
            branch = self.dropout(branch)
        return residual + branch
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class LigerFusedNeighborhoodAttentionConfig:
    """
    Configuration class for Fused Neighborhood Attention.

    A plain holder for the constructor arguments of the neighborhood attention
    modules, so they can be configured per model architecture and serialized
    with ``to_dict``.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        num_heads: int = 12,
        kernel_size: int = 7,
        dilation: int = 1,
        bias: bool = True,
        dropout: float = 0.0,
        layer_norm_eps: float = 1e-5,
        scale: Optional[float] = None,
    ):
        # Attention geometry.
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        self.dilation = dilation
        # Projection / regularization / normalization options.
        self.bias = bias
        self.dropout = dropout
        self.layer_norm_eps = layer_norm_eps
        self.scale = scale

    def to_dict(self):
        """Return all configuration fields as a plain ``dict`` (in constructor order)."""
        return dict(
            hidden_size=self.hidden_size,
            num_heads=self.num_heads,
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            bias=self.bias,
            dropout=self.dropout,
            layer_norm_eps=self.layer_norm_eps,
            scale=self.scale,
        )
|