liger-kernel-nightly 0.5.10.dev20250524022630__py3-none-any.whl → 0.5.10.dev20250526154149__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,207 @@
+ import torch
+ import torch.nn.functional as F
+ import triton
+ import triton.language as tl
+
+ from torch.nn.modules.utils import _pair
+
+ from liger_kernel.ops.softmax import _softmax_forward
+ from liger_kernel.ops.sparsemax import _sparsemax_backward
+ from liger_kernel.ops.sparsemax import _sparsemax_forward
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ @triton.jit
+ def _mask_fwd_kernel(
+     scores_ptr,
+     out_ptr,
+     stride_b,
+     stride_m,
+     stride_n,
+     L,
+     mask_val: tl.constexpr,
+     BLOCK: tl.constexpr,
+     num_warps: tl.constexpr,
+ ):
+     row_block = tl.program_id(0)
+     col_block = tl.program_id(1)
+     batch_id = tl.program_id(2)
+
+     row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
+     col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
+     in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
+
+     base = scores_ptr + batch_id * stride_b
+     offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
+     future = col_idx[None, :] > row_idx[:, None]
+     mask_load = in_bounds & ~future
+     out = tl.load(base + offs, mask=mask_load, other=mask_val, cache_modifier=".ca")
+     tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".cs")
+
+
+ @triton.jit
+ def _mask_bwd_kernel(
+     grad_in_ptr, out_ptr, stride_b, stride_m, stride_n, L, BLOCK: tl.constexpr, num_warps: tl.constexpr
+ ):
+     row_block = tl.program_id(0)
+     col_block = tl.program_id(1)
+     batch_id = tl.program_id(2)
+
+     row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
+     col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
+     in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
+
+     base = grad_in_ptr + batch_id * stride_b
+     offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
+     grad_vals = tl.load(base + offs, mask=in_bounds, other=0.0, cache_modifier=".ca")
+
+     future = col_idx[None, :] > row_idx[:, None]
+     zero = tl.zeros(grad_vals.shape, dtype=grad_vals.dtype)
+     out = tl.where(future, zero, grad_vals)
+
+     tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".wb")
+
+
+ def _mask_inf_forward(scores: torch.Tensor) -> torch.Tensor:
+     *batch, L, _ = scores.shape
+     N = int(torch.prod(torch.tensor(batch))) if batch else 1
+     scores_f = scores.view(N, L, L)
+     out = torch.empty_like(scores_f)
+
+     sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
+     BLOCK_SIZE, num_warps = calculate_settings(L)
+     grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+     _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=-1e9, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+     return out.view(*batch, L, L)
+
+
+ def _mask_inf_backward(grad: torch.Tensor) -> torch.Tensor:
+     *batch, L, _ = grad.shape
+     N = int(torch.prod(torch.tensor(batch))) if batch else 1
+     grad_f = grad.view(N, L, L)
+     out = torch.empty_like(grad_f)
+
+     sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
+     BLOCK_SIZE, num_warps = calculate_settings(L)
+     grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+     _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+     return out.view(*batch, L, L)
+
+
+ def _mask_zero_forward(scores: torch.Tensor) -> torch.Tensor:
+     *batch, L, _ = scores.shape
+     N = int(torch.prod(torch.tensor(batch))) if batch else 1
+     scores_f = scores.view(N, L, L)
+     out = torch.empty_like(scores_f)
+
+     sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
+     BLOCK_SIZE, num_warps = calculate_settings(L)
+     grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+     _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=0.0, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+     return out.view(*batch, L, L)
+
+
+ def _mask_zero_backward(grad: torch.Tensor) -> torch.Tensor:
+     *batch, L, _ = grad.shape
+     N = int(torch.prod(torch.tensor(batch))) if batch else 1
+     grad_f = grad.view(N, L, L)
+     out = torch.empty_like(grad_f)
+
+     sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
+     BLOCK_SIZE, num_warps = calculate_settings(L)
+     grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+     _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+     return out.view(*batch, L, L)
+
+
+ class LigerMultiTokenAttentionFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, scores, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, sparse=False):
+         scores_inf = _mask_inf_forward(scores)
+
+         out_flat_sparse = None
+         activation_output = None
+
+         ctx.sparse = sparse
+
+         if sparse:
+             if scores_inf.dtype != torch.float32:
+                 raise RuntimeError("Liger sparse multi-token attention currently only supports fp32 input scores")
+             probs_sparse, out_flat_sparse = _sparsemax_forward(scores_inf, dim=-1)
+             activation_output = probs_sparse
+             ctx.save_for_backward(scores_inf, activation_output, out_flat_sparse, weight, bias)
+             ctx.out_flat_sparse_saved = True
+         else:
+             probs_softmax, _, _, _ = _softmax_forward(scores_inf)
+             activation_output = probs_softmax
+             ctx.save_for_backward(scores_inf, activation_output, weight, bias)
+             ctx.out_flat_sparse_saved = False
+
+         out_conv = F.conv2d(
+             activation_output,
+             weight,
+             bias,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+         )
+
+         out = _mask_zero_forward(out_conv)
+
+         ctx.stride = _pair(stride)
+         ctx.padding = _pair(padding)
+         ctx.dilation = _pair(dilation)
+         ctx.groups = groups
+         ctx.dim = -1
+
+         return out
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_out):
+         if ctx.out_flat_sparse_saved:
+             scores_inf, activation_output, out_flat_sparse, weight, bias = ctx.saved_tensors
+         else:
+             scores_inf, activation_output, weight, bias = ctx.saved_tensors
+             out_flat_sparse = None
+
+         use_sparsemax = ctx.sparse
+         dim = ctx.dim
+         stride, padding, dilation, groups = (ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
+
+         grad_conv = _mask_zero_backward(grad_out)
+
+         grad_probs = F.conv_transpose2d(
+             grad_conv, weight, None, stride=stride, padding=padding, dilation=dilation, groups=groups
+         )
+
+         grad_weight = torch.nn.grad.conv2d_weight(
+             input=activation_output,
+             weight_size=weight.shape,
+             grad_output=grad_conv,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+         )
+         grad_bias = None
+         if bias is not None:
+             grad_bias = grad_conv.sum(dim=(0, 2, 3))
+
+         grad_scores_inf = None
+         if use_sparsemax:
+             if not ctx.out_flat_sparse_saved or out_flat_sparse is None:
+                 raise RuntimeError("Internal error: Sparse flag is set but sparse tensor was not saved.")
+             grad_scores_inf = _sparsemax_backward(grad_probs, out_flat_sparse, dim=dim)
+         else:
+             grad_probs_cont = grad_probs
+             probs_cont = activation_output
+             dot = (grad_probs_cont * probs_cont).sum(dim=-1, keepdim=True)
+             grad_scores_inf = probs_cont * (grad_probs_cont - dot)
+
+         grad_scores = _mask_inf_backward(grad_scores_inf)
+
+         return (grad_scores, grad_weight, grad_bias, None, None, None, None, None)
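Taken together, the file above fuses the three steps of multi-token attention: mask future positions before the activation, run softmax (or sparsemax when sparse=True), convolve the attention map, then zero the future positions again. For orientation only, here is a minimal pure-PyTorch sketch of the same forward computation, assuming an odd kernel size with padding = K // 2 so the (L, L) map keeps its shape; it is not part of the package:

import torch
import torch.nn.functional as F


def multi_token_attention_reference(scores, weight, bias=None):
    # scores: (B, C_in, L, L); weight: (C_out, C_in, K, K) with K odd.
    L = scores.size(-1)
    future = torch.triu(torch.ones(L, L, dtype=torch.bool, device=scores.device), diagonal=1)

    # Mask future key positions before the activation (the Triton kernel uses -1e9 as mask_val).
    masked = scores.masked_fill(future, -1e9)
    probs = F.softmax(masked, dim=-1)

    # Mix attention weights across nearby query/key positions with a 2D convolution.
    out = F.conv2d(probs, weight, bias, padding=weight.size(-1) // 2)

    # Re-apply the causal mask so the mixed weights stay zero on future positions.
    return out.masked_fill(future, 0.0)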
@@ -0,0 +1,201 @@
+ from typing import Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ @triton.jit
+ def _softmax_single_block_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     row_id = tl.program_id(0)
+     offs = tl.arange(0, BLOCK_SIZE)
+     mask = offs < n_cols
+
+     x = tl.load(X_ptr + row_id * X_row_stride + offs, mask=mask, other=-float("inf"), cache_modifier=".ca")
+     m = tl.max(x, axis=0)
+     e = tl.exp(x - m)
+     d = tl.sum(e, axis=0)
+     y = e / d
+     tl.store(Y_ptr + row_id * Y_row_stride + offs, y, mask=mask, cache_modifier=".cs")
+
+
+ @triton.jit
+ def _softmax_multi_block_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     row_id = tl.program_id(0)
+     offs = tl.arange(0, BLOCK_SIZE)
+
+     m = tl.float32(-float("inf"))
+     d = tl.float32(0.0)
+     for start in tl.range(0, n_cols, BLOCK_SIZE):
+         idx = start + offs
+         mask = idx < n_cols
+         xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+         blk_max = tl.max(xblk, axis=0)
+         new_m = tl.max(m, blk_max)
+         d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
+         m = new_m
+
+     for start in tl.range(0, n_cols, BLOCK_SIZE):
+         idx = start + offs
+         mask = idx < n_cols
+         xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+         yblk = tl.exp(xblk - m) / d
+         tl.store(Y_ptr + row_id * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
+
+
+ @triton.jit
+ def _softmax_single_block_backward_kernel(
+     dy_ptr,
+     dy_stride,
+     y_ptr,
+     y_stride,
+     dx_ptr,
+     dx_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     row_id = tl.program_id(0)
+     offs = tl.arange(0, BLOCK_SIZE)
+     mask = offs < n_cols
+
+     dy = tl.load(dy_ptr + row_id * dy_stride + offs, mask=mask, other=0.0)
+     y = tl.load(y_ptr + row_id * y_stride + offs, mask=mask, other=0.0, cache_modifier=".ca")
+     dot = tl.sum(dy * y, axis=0)
+     dx = y * (dy - dot)
+     tl.store(dx_ptr + row_id * dx_stride + offs, dx, mask=mask, cache_modifier=".wb")
+
+
+ @triton.jit
+ def _softmax_multi_block_backward_kernel(
+     dy_ptr,
+     dy_stride,
+     y_ptr,
+     y_stride,
+     dx_ptr,
+     dx_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     row_id = tl.program_id(0)
+     offs = tl.arange(0, BLOCK_SIZE)
+     acc = tl.float32(0.0)
+
+     for start in tl.range(0, n_cols, BLOCK_SIZE):
+         idx = start + offs
+         mask = idx < n_cols
+         dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+         y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+         acc += tl.sum(dy_blk * y_blk, axis=0)
+
+     for start in tl.range(0, n_cols, BLOCK_SIZE):
+         idx = start + offs
+         mask = idx < n_cols
+         dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+         y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+         dx_blk = y_blk * (dy_blk - acc)
+         tl.store(dx_ptr + row_id * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
+
+
+ def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
+     *batch, n_cols = x.shape
+     x2d = x.contiguous().view(-1, n_cols)
+     n_rows = x2d.shape[0]
+
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+     y2d = torch.empty_like(x2d)
+
+     if n_cols <= BLOCK_SIZE:
+         _softmax_single_block_forward_kernel[(n_rows,)](
+             y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+         )
+         multi_block_launch = False
+     else:
+         _softmax_multi_block_forward_kernel[(n_rows,)](
+             y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+         )
+         multi_block_launch = True
+
+     return y2d.view(*batch, n_cols), BLOCK_SIZE, num_warps, multi_block_launch
+
+
+ def _softmax_backward(
+     dy: torch.Tensor,
+     y: torch.Tensor,
+     BLOCK_SIZE: int,
+     num_warps: int,
+     multi_block_launch: bool,
+ ) -> torch.Tensor:
+     *batch, n_cols = dy.shape
+     dy2d = dy.contiguous().view(-1, n_cols)
+     y2d = y.contiguous().view(-1, n_cols)
+     n_rows = dy2d.shape[0]
+     dx2d = torch.empty_like(dy2d)
+
+     if not multi_block_launch and n_cols <= BLOCK_SIZE:
+         _softmax_single_block_backward_kernel[(n_rows,)](
+             dy2d,
+             dy2d.stride(0),
+             y2d,
+             y2d.stride(0),
+             dx2d,
+             dx2d.stride(0),
+             n_cols,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+     else:
+         _softmax_multi_block_backward_kernel[(n_rows,)](
+             dy2d,
+             dy2d.stride(0),
+             y2d,
+             y2d.stride(0),
+             dx2d,
+             dx2d.stride(0),
+             n_cols,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+
+     return dx2d.view(*batch, n_cols)
+
+
+ class LigerSoftmaxFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, input_: torch.Tensor):
+         y, BLOCK_SIZE, num_warps, multi_block_launch = _softmax_forward(input_)
+         ctx.save_for_backward(y)
+         ctx.BLOCK_SIZE = BLOCK_SIZE
+         ctx.num_warps = num_warps
+         ctx.multi_block_launch = multi_block_launch
+         return y
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output):
+         (y,) = ctx.saved_tensors
+         dx = _softmax_backward(
+             grad_output,
+             y,
+             ctx.BLOCK_SIZE,
+             ctx.num_warps,
+             ctx.multi_block_launch,
+         )
+         return dx
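_softmax_forward flattens every leading dimension into rows and launches one program per row; when a row does not fit in a single block, the multi-block kernel makes a streaming pass to compute the running maximum and denominator, then a second pass to write the normalized values. A minimal usage sketch, assuming a CUDA device is available:

import torch

from liger_kernel.ops.softmax import LigerSoftmaxFunction

x = torch.randn(4, 8, 256, device="cuda", requires_grad=True)
y = LigerSoftmaxFunction.apply(x)                     # softmax over the last dimension
print(torch.allclose(y, torch.softmax(x, dim=-1), atol=1e-5))

(y * torch.randn_like(y)).sum().backward()            # backward reuses the saved probabilities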
@@ -1,3 +1,5 @@
+ from typing import Tuple
+
  import torch
  import triton
  import triton.language as tl
@@ -105,63 +107,73 @@ def _sparsemax_backward_kernel(
          tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")


+ def _sparsemax_forward(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
+     if dim < 0:
+         dim += x.dim()
+     x_sw = x.transpose(dim, -1).contiguous()
+     n_cols = x_sw.size(-1)
+     n_rows = x_sw.numel() // n_cols
+     x_flat = x_sw.view(n_rows, n_cols)
+     x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
+
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+     out_flat = torch.empty_like(x_flat)
+     grid = (n_rows,)
+     _sparsemax_forward_kernel[grid](
+         x_flat,
+         x_flat.stride(0),
+         x_sorted_flat,
+         x_sorted_flat.stride(0),
+         out_flat,
+         out_flat.stride(0),
+         n_cols,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+     )
+
+     y = out_flat.view_as(x_sw).transpose(dim, -1)
+     return y, out_flat
+
+
+ def _sparsemax_backward(
+     grad_out: torch.Tensor,
+     out_flat: torch.Tensor,
+     dim: int,
+ ) -> torch.Tensor:
+     grad_sw = grad_out.transpose(dim, -1).contiguous()
+     n_cols = grad_sw.size(-1)
+     n_rows = grad_sw.numel() // n_cols
+     go_flat = grad_sw.view(n_rows, n_cols)
+
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+     dx_flat = torch.empty_like(go_flat)
+     grid = (n_rows,)
+     _sparsemax_backward_kernel[grid](
+         out_flat,
+         go_flat,
+         dx_flat,
+         out_flat.stride(0),
+         n_cols,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+     )
+
+     dx = dx_flat.view_as(grad_sw).transpose(dim, -1)
+     return dx
+
+
  class LigerSparsemaxFunction(torch.autograd.Function):
      @staticmethod
      @ensure_contiguous
      def forward(ctx, x: torch.Tensor, dim: int):
-         if dim < 0:
-             dim += x.dim()
-         ctx.dim = dim
-
-         x_sw = x.transpose(dim, -1).contiguous()
-         n_cols = x_sw.size(-1)
-         n_rows = x_sw.numel() // n_cols
-         x_flat = x_sw.view(n_rows, n_cols)
-
-         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-         out_flat = torch.empty_like(x_flat)
-         grid = (n_rows,)
-
-         x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
-
-         _sparsemax_forward_kernel[grid](
-             x_flat,
-             x_flat.stride(0),
-             x_sorted_flat,
-             x_sorted_flat.stride(0),
-             out_flat,
-             out_flat.stride(0),
-             n_cols,
-             BLOCK_SIZE=BLOCK_SIZE,
-             num_warps=num_warps,
-         )
-
+         y, out_flat = _sparsemax_forward(x, dim)
          ctx.save_for_backward(out_flat)
-         return out_flat.view_as(x_sw).transpose(dim, -1)
+         ctx.dim = dim
+         return y

      @staticmethod
      @ensure_contiguous
      def backward(ctx, grad_out: torch.Tensor):
          (out_flat,) = ctx.saved_tensors
-         dim = ctx.dim
-
-         go_sw = grad_out.transpose(dim, -1).contiguous()
-         n_cols = go_sw.size(-1)
-         n_rows = go_sw.numel() // n_cols
-         go_flat = go_sw.view(n_rows, n_cols)
-
-         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-         gi_flat = torch.empty_like(go_flat)
-         grid = (n_rows,)
-
-         _sparsemax_backward_kernel[grid](
-             out_flat,
-             go_flat,
-             gi_flat,
-             out_flat.stride(0),
-             n_cols,
-             BLOCK_SIZE=BLOCK_SIZE,
-             num_warps=num_warps,
-         )
-
-         return gi_flat.view_as(go_sw).transpose(dim, -1), None
+         dx = _sparsemax_backward(grad_out, out_flat, ctx.dim)
+         return dx, None
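The sparsemax change is a refactor: the launch logic moves out of LigerSparsemaxFunction into the module-level _sparsemax_forward / _sparsemax_backward helpers so the new multi-token attention op can call them directly, with the same behavior as before. A small sanity sketch of the usual sparsemax properties (rows sum to 1 and typically contain exact zeros), assuming a CUDA device:

import torch

from liger_kernel.ops.sparsemax import LigerSparsemaxFunction

x = torch.randn(2, 16, device="cuda")
p = LigerSparsemaxFunction.apply(x, -1)   # sparsemax along the last dimension
print(p.sum(dim=-1))                      # each row sums to 1
print((p == 0).float().mean())            # unlike softmax, many entries are exactly zero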
@@ -9,9 +9,11 @@ from liger_kernel.ops.group_norm import LigerGroupNormFunction
  from liger_kernel.ops.jsd import LigerJSDFunction
  from liger_kernel.ops.kl_div import LigerKLDivLossFunction
  from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+ from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
  from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
  from liger_kernel.ops.rope import LigerRopeFunction
+ from liger_kernel.ops.softmax import LigerSoftmaxFunction
  from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
  from liger_kernel.ops.tvd import LigerTVDLossFunction
@@ -167,6 +169,34 @@ def liger_sparsemax(
      return LigerSparsemaxFunction.apply(input, dim)


+ def liger_multi_token_attention(
+     scores,
+     weight,
+     bias=None,
+     stride: int = 1,
+     padding: int = 0,
+     dilation: int = 1,
+     groups: int = 1,
+     sparse: bool = False,
+ ):
+     """
+     Functional interface for multi-token attention.
+
+     Args:
+         scores: Input tensor of shape (B, C_in, L, L)
+         weight: Convolution weight tensor of shape (C_out, C_in // groups, K, K)
+         bias: Optional bias tensor of shape (C_out,)
+         stride: Stride for the convolution (default: 1)
+         padding: Padding for the convolution (default: 0)
+         dilation: Dilation factor for the convolution (default: 1)
+         groups: Number of groups for the convolution (default: 1)
+         sparse: Specifies if input tensors are expected to be sparse (default: False)
+     Returns:
+         Output tensor after applying multi-token attention.
+     """
+     return LigerMultiTokenAttentionFunction.apply(scores, weight, bias, stride, padding, dilation, groups, sparse)
+
+
  def liger_tvd(
      input,
      target,
@@ -203,5 +233,9 @@ def liger_swiglu(a, b):
      return LigerSiLUMulFunction.apply(a, b)


+ def liger_softmax(x):
+     return LigerSoftmaxFunction.apply(x)
+
+
  def liger_dyt(x, alpha, gamma, beta):
      return LigerDyTFunction.apply(x, alpha, gamma, beta)
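A short call sketch for the new liger_multi_token_attention wrapper (liger_softmax simply forwards to LigerSoftmaxFunction.apply), assuming fp32 CUDA tensors and padding = K // 2 so the (L, L) attention map keeps its shape; passing sparse=True swaps the softmax activation for sparsemax:

import torch

from liger_kernel.transformers.functional import liger_multi_token_attention

B, C, L, K = 2, 4, 128, 3
scores = torch.randn(B, C, L, L, device="cuda", requires_grad=True)
weight = torch.randn(C, C, K, K, device="cuda", requires_grad=True)
bias = torch.randn(C, device="cuda", requires_grad=True)

out = liger_multi_token_attention(scores, weight, bias, padding=K // 2)
out.mean().backward()   # gradients flow to scores, weight, and bias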
@@ -0,0 +1,64 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+
+ from torch.nn.modules.utils import _pair
+
+ from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
+
+
+ class LigerMultiTokenAttention(nn.Module):
+     """
+     Multi-Token Attention:
+         out = mask_{0}(conv2d(softmax(mask_{-\inf}(scores))))
+
+     Reference: https://arxiv.org/pdf/2504.00927
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         padding: int = 0,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         sparse: bool = False,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.kernel_size = _pair(kernel_size)
+         self.stride = _pair(stride)
+         self.padding = _pair(padding)
+         self.dilation = _pair(dilation)
+         self.groups = groups
+         self.sparse = sparse
+
+         self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, *self.kernel_size))
+         if bias:
+             self.bias = nn.Parameter(torch.empty(out_channels))
+         else:
+             self.register_parameter("bias", None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+         if self.bias is not None:
+             nn.init.zeros_(self.bias)
+
+     def forward(self, scores: torch.Tensor) -> torch.Tensor:
+         return LigerMultiTokenAttentionFunction.apply(
+             scores,
+             self.weight,
+             self.bias,
+             self.stride,
+             self.padding,
+             self.dilation,
+             self.groups,
+             self.sparse,
+         )
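A brief usage sketch for the module wrapper, assuming a CUDA device and an odd kernel size with matching padding so the output keeps the input's (B, C, L, L) shape:

import torch

from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention

mta = LigerMultiTokenAttention(in_channels=4, out_channels=4, kernel_size=3, padding=1).cuda()
scores = torch.randn(2, 4, 128, 128, device="cuda")
out = mta(scores)   # (2, 4, 128, 128): causally masked, convolved attention weights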
@@ -0,0 +1,12 @@
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.softmax import LigerSoftmaxFunction
+
+
+ class LigerKernelSoftmax(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x: torch.Tensor):
+         return LigerSoftmaxFunction.apply(x)
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.5.10.dev20250524022630
+ Version: 0.5.10.dev20250526154149
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -26,10 +26,12 @@ liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0
  liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
  liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
  liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
+ liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
  liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
  liger_kernel/ops/rms_norm.py,sha256=PP27OIBmV9By63i13jot9ylDowW0nuxY_JFIkaPLgL4,12078
  liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
- liger_kernel/ops/sparsemax.py,sha256=t7JWIyzq1piikXUufayFzsfkzVaCYU-hXPuMs7839pk,4850
+ liger_kernel/ops/softmax.py,sha256=tgORx6MK1IDDtZKqGarj0IPIVjqAIEUXXYPiinhRdtI,5864
+ liger_kernel/ops/sparsemax.py,sha256=AeWe1xgkHJFEKWTj2vu_0hj7LztGvjqXAps-QTpCY0U,5087
  liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
  liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
@@ -40,7 +42,7 @@ liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawX
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
  liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
  liger_kernel/transformers/fsdp.py,sha256=CUiyjTmjkjY7pLXQv8ly9rnzgXw6529csd9pvtJNMYc,3096
- liger_kernel/transformers/functional.py,sha256=2YBfvtdU1GRZuRpJhHgJXeGYa1RvmO6-qQvrKQrLJK4,5259
+ liger_kernel/transformers/functional.py,sha256=QmnAFpRgIbp9Rzlfp8QibwiEbf5BUcANxfY68an7o8c,6444
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
@@ -51,9 +53,11 @@ liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCc
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
  liger_kernel/transformers/monkey_patch.py,sha256=DKv5-4KyXLiVhAJ9WVFv1I1i1DzjaudTrhqx6EVYViU,74505
+ liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
  liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
+ liger_kernel/transformers/softmax.py,sha256=u7bFo35-cjaAm9of6-DLzmkaNFELOM-9AgyrcvUPifw,270
  liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
  liger_kernel/transformers/swiglu.py,sha256=LZ8YeLIdv2k46JleZMjzubGk98smt6t780kSgcVLsQk,3454
  liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
@@ -82,9 +86,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel_nightly-0.5.10.dev20250524022630.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel_nightly-0.5.10.dev20250524022630.dist-info/METADATA,sha256=kjNZA78siLFmGihrwvXrUIiLnTNaJoZglOkbtihcynk,24113
- liger_kernel_nightly-0.5.10.dev20250524022630.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel_nightly-0.5.10.dev20250524022630.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
- liger_kernel_nightly-0.5.10.dev20250524022630.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel_nightly-0.5.10.dev20250524022630.dist-info/RECORD,,
+ liger_kernel_nightly-0.5.10.dev20250526154149.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel_nightly-0.5.10.dev20250526154149.dist-info/METADATA,sha256=0CXMJx6ef3SurofjUlAWwMaj-prwFf-xg_nMo-n6UPE,24113
+ liger_kernel_nightly-0.5.10.dev20250526154149.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel_nightly-0.5.10.dev20250526154149.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+ liger_kernel_nightly-0.5.10.dev20250526154149.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel_nightly-0.5.10.dev20250526154149.dist-info/RECORD,,