liger-kernel 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +1 -1
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  6. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  7. liger_kernel/ops/dyt.py +111 -179
  8. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  9. liger_kernel/ops/geglu.py +1 -1
  10. liger_kernel/ops/grpo_loss.py +310 -0
  11. liger_kernel/ops/multi_token_attention.py +207 -0
  12. liger_kernel/ops/rms_norm.py +265 -54
  13. liger_kernel/ops/softmax.py +201 -0
  14. liger_kernel/ops/sparsemax.py +179 -0
  15. liger_kernel/ops/swiglu.py +1 -1
  16. liger_kernel/transformers/__init__.py +8 -0
  17. liger_kernel/transformers/dyt.py +5 -3
  18. liger_kernel/transformers/fsdp.py +55 -0
  19. liger_kernel/transformers/functional.py +70 -0
  20. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  21. liger_kernel/transformers/grpo_loss.py +98 -0
  22. liger_kernel/transformers/model/gemma.py +25 -16
  23. liger_kernel/transformers/model/gemma2.py +27 -14
  24. liger_kernel/transformers/model/gemma3.py +62 -106
  25. liger_kernel/transformers/model/glm4.py +16 -13
  26. liger_kernel/transformers/model/llama.py +81 -18
  27. liger_kernel/transformers/model/llama4.py +108 -0
  28. liger_kernel/transformers/model/llava.py +95 -132
  29. liger_kernel/transformers/model/mistral.py +13 -14
  30. liger_kernel/transformers/model/mixtral.py +16 -15
  31. liger_kernel/transformers/model/mllama.py +16 -14
  32. liger_kernel/transformers/model/olmo2.py +16 -13
  33. liger_kernel/transformers/model/paligemma.py +8 -9
  34. liger_kernel/transformers/model/phi3.py +25 -16
  35. liger_kernel/transformers/model/qwen2.py +24 -15
  36. liger_kernel/transformers/model/qwen2_5_vl.py +41 -97
  37. liger_kernel/transformers/model/qwen2_vl.py +38 -106
  38. liger_kernel/transformers/model/qwen3.py +11 -9
  39. liger_kernel/transformers/model/qwen3_moe.py +132 -0
  40. liger_kernel/transformers/monkey_patch.py +424 -81
  41. liger_kernel/transformers/multi_token_attention.py +64 -0
  42. liger_kernel/transformers/rms_norm.py +40 -4
  43. liger_kernel/transformers/softmax.py +12 -0
  44. liger_kernel/transformers/sparsemax.py +16 -0
  45. liger_kernel/transformers/swiglu.py +21 -0
  46. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  47. liger_kernel/utils.py +11 -0
  48. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +41 -21
  49. liger_kernel-0.6.0.dist-info/RECORD +97 -0
  50. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  51. liger_kernel/transformers/gema3_rms.py +0 -8
  52. liger_kernel-0.5.9.dist-info/RECORD +0 -84
  53. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  54. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  55. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py CHANGED
@@ -40,7 +40,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
40
40
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
41
41
  tanh_result = tanh(tanh_arg)
42
42
  geglu_a = 0.5 * a_row * (1 + tanh_result)
43
- c_row = geglu_a * b_row
43
+ c_row = geglu_a.cast(b_row.dtype) * b_row
44
44
  tl.store(c + col_offsets, c_row, mask=mask)
45
45
 
46
46
 
@@ -0,0 +1,310 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
@triton.jit
def _selective_log_softmax_kernel(
    LOGITS,
    INPUT_IDS,
    LOG_P,
    MASK,
    TEMPERATURE,
    stride_input_ids_b,
    L: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr = 4096,
):
    """Write the temperature-scaled log-probability of one selected token.

    Each program handles a single (batch, position) pair, taken from
    program ids 0 and 1. It computes the log-sum-exp over the N logits of
    that row in chunks of BLOCK_N, then stores
    ``logit[id] / TEMPERATURE - lse`` into LOG_P.

    The pointer arithmetic below treats LOGITS as densely packed with
    ``(L + 1) * N`` elements per batch entry and N per position, and LOG_P
    as L elements per batch entry. MASK is optional (may be None).
    """
    # Promote the program ids to int64 before they are multiplied into offsets.
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    # Advance every pointer to the element(s) owned by this program.
    LOGITS += off_b * (L + 1) * N + off_l * N
    INPUT_IDS += off_b * stride_input_ids_b + off_l
    LOG_P += off_b * L + off_l

    if MASK is not None:
        # NOTE: MASK is indexed with the INPUT_IDS batch stride, so both
        # tensors must share the same layout.
        MASK += off_b * stride_input_ids_b + off_l
        not_skip = tl.load(MASK)
        if not_skip == 0:
            # Masked-out position: return without writing LOG_P.
            return

    # Chunked log-sum-exp: m_i is the running maximum, l_i the running sum
    # of exp(logit - m_i), rescaled by alpha whenever the maximum grows.
    m_i = float("-inf")
    l_i = 0.0
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        # Lanes past N load -inf, which contributes exp(-inf) = 0 to the sum.
        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
        new_m_i = tl.maximum(m_i, tl.max(logits))
        alpha = tl.exp(m_i - new_m_i)
        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
        m_i = new_m_i
    lse = m_i + tl.log(l_i)

    # Gather the logit of the selected token id and normalise it.
    ids = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
    logp = x - lse
    tl.store(LOG_P, logp)
46
+
47
+
48
# Computes old_logp / ref_logp. Per the original author this reduces peak memory
# by about 10 GB. The result does not require grad.
@torch.no_grad()
def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
    """Return per-token log-probabilities of ``input_ids`` under ``logits``.

    Args:
        logits: Contiguous tensor of shape ``(B, L + 1, N)``. Only the first
            ``L`` positions of each batch entry are scored.
        input_ids: Token ids with at least ``L`` columns; the last ``L``
            columns are used.
        temperature: Divisor applied to the logits before the softmax.
        mask: Optional tensor with at least ``L`` columns; the last ``L``
            columns are used. Positions whose mask value is 0 are skipped and
            keep a log-probability of 0.

    Returns:
        A float32 tensor of shape ``(B, L)``.

    Raises:
        AssertionError: If ``logits`` is not contiguous.
        ValueError: If ``input_ids`` or ``mask`` do not provide ``L`` columns
            for each of the ``B`` rows.
    """
    assert logits.is_contiguous()
    B, L_ADD_1, N = logits.shape
    L = L_ADD_1 - 1
    log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
    if B == 0 or L == 0:
        # Nothing to score. This also avoids ``[:, -0:]`` (which would select
        # every column) and launching an empty grid.
        return log_p

    # The kernel walks each row with unit stride and reuses the input_ids
    # batch stride for the mask, so give both tensors the same dense layout.
    input_ids = input_ids[:, -L:].contiguous()
    if input_ids.shape != (B, L):
        raise ValueError(f"input_ids must provide shape ({B}, {L}) after slicing, got {tuple(input_ids.shape)}")
    if mask is not None:
        mask = mask[:, -L:].contiguous()
        if mask.shape != input_ids.shape:
            raise ValueError(f"mask must provide shape ({B}, {L}) after slicing, got {tuple(mask.shape)}")

    kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
    _selective_log_softmax_kernel[(B, L)](
        logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
    )
    return log_p
63
+
64
+
65
+ # @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
66
+ # for BLOCK_N in [2048, 4096, 8192]
67
+ # for ns in [1, 2, 4]
68
+ # for nw in [1, 2, 4, 8, 16]],
69
+ # key=['N'])
70
@triton.jit
def _grpo_loss_fwd_kernel(
    LOGITS,
    OLD_LOGP,
    REF_LOGP,
    INPUT_IDS,
    COMPLETION_MASK,
    ADVANTAGES,
    LOSS,
    LSE,
    KL,
    IS_CLIPPED,
    TEMPERATURE,
    BETA: tl.constexpr,
    EPS_LOW,
    EPS_HIGH,
    L: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr = 4096,
):
    """Forward pass of the per-token GRPO loss for one (batch, position) pair.

    Stores the per-token loss, the log-sum-exp of the temperature-scaled
    logits, a clipped flag and, when BETA is non-zero, the KL term.

    The offsets below treat LOGITS as ``(L + 1) * N`` elements per batch
    entry, ADVANTAGES as one element per batch entry, and every other
    per-token buffer as L elements per batch entry with unit stride.
    OLD_LOGP and COMPLETION_MASK are optional (may be None).
    """
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    if COMPLETION_MASK is not None:
        COMPLETION_MASK += off_b * L + off_l
        not_skip = tl.load(COMPLETION_MASK)
        if not_skip == 0:
            # Masked-out token: return before writing any output, so the
            # output buffers keep their caller-provided values.
            return

    # Advance the pointers to the element(s) owned by this program.
    LOGITS += off_b * (L + 1) * N + off_l * N
    INPUT_IDS += off_b * L + off_l
    ADVANTAGES += off_b
    LOSS += off_b * L + off_l
    LSE += off_b * L + off_l
    IS_CLIPPED += off_b * L + off_l

    # Chunked log-sum-exp: m_i is the running maximum, l_i the running sum
    # of exp(logit - m_i), rescaled by alpha whenever the maximum grows.
    m_i = float("-inf")
    l_i = 0.0
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
        new_m_i = tl.maximum(m_i, tl.max(logits))
        alpha = tl.exp(m_i - new_m_i)
        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
        m_i = new_m_i
    lse = m_i + tl.log(l_i)

    # Log-probability of the token that was actually generated.
    idx = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
    logp = x - lse
    if OLD_LOGP is None:
        # Without an old policy the ratio below is exp(0) = 1.
        old_logp = logp
    else:
        OLD_LOGP += off_b * L + off_l
        old_logp = tl.load(OLD_LOGP).to(tl.float32)
    # Probability ratio and its clamped counterpart.
    coef_1 = tl.exp(logp - old_logp)
    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
    advantage = tl.load(ADVANTAGES).to(tl.float32)
    per_token_loss1 = coef_1 * advantage
    per_token_loss2 = coef_2 * advantage
    # Negated minimum of the unclipped and clipped objectives.
    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
    is_clipped = per_token_loss1 < per_token_loss2

    if BETA != 0.0:
        # REF_LOGP and KL are only touched when the KL penalty is enabled.
        REF_LOGP += off_b * L + off_l
        KL += off_b * L + off_l
        ref_logp = tl.load(REF_LOGP).to(tl.float32)
        # kl = exp(d) - d - 1 with d = ref_logp - logp.
        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
        per_token_loss += BETA * kl
        tl.store(KL, kl)

    tl.store(LOSS, per_token_loss)
    # LSE is stored so the backward kernel can load it instead of recomputing.
    tl.store(LSE, lse)
    tl.store(IS_CLIPPED, is_clipped)
144
+
145
+
146
+ # @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
147
+ # for BLOCK_N in [2048, 4096, 8192]
148
+ # for ns in [1, 2, 4]
149
+ # for nw in [1, 2, 4, 8, 16]],
150
+ # key=['N'])
151
@triton.jit
def _grpo_loss_bwd_kernel(
    DLOSS,
    DLOGITS,
    LOGITS,
    OLD_LOGP,
    REF_LOGP,
    INPUT_IDS,
    ADVANTAGES,
    COMPLETION_MASK,
    LSE,
    TEMPERATURE,
    BETA: tl.constexpr,
    EPS_LOW,
    EPS_HIGH,
    loss_stride0,
    loss_stride1,
    L: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr = 4096,
):
    """Backward pass of the per-token GRPO loss for one (batch, position) pair.

    Writes the gradient with respect to the N logits of this position into
    DLOGITS, using the log-sum-exp loaded from LSE. DLOSS is read through
    explicit strides (loss_stride0, loss_stride1); the remaining per-token
    buffers are indexed as L elements per batch entry with unit stride, and
    LOGITS / DLOGITS as ``(L + 1) * N`` elements per batch entry.
    """
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    # DLOGITS is advanced first because the masked-out branch below writes to it.
    DLOGITS += off_b * (L + 1) * N + off_l * N
    if COMPLETION_MASK is not None:
        COMPLETION_MASK += off_b * L + off_l
        not_skip = tl.load(COMPLETION_MASK)
        if not_skip == 0:
            # Masked-out token: its whole gradient row is zero.
            for start in range(0, N, BLOCK_N):
                cols = tl.arange(0, BLOCK_N) + start
                tl.store(DLOGITS + cols, 0.0, mask=cols < N)
            return

    LOGITS += off_b * (L + 1) * N + off_l * N
    DLOSS += off_b * loss_stride0 + off_l * loss_stride1
    INPUT_IDS += off_b * L + off_l
    ADVANTAGES += off_b
    LSE += off_b * L + off_l

    dloss = tl.load(DLOSS).to(tl.float32)
    lse = tl.load(LSE).to(tl.float32)

    # Recompute the selected token's log-probability from the stored LSE.
    idx = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
    logp = x - lse
    if OLD_LOGP is None:
        old_logp = logp
    else:
        OLD_LOGP += off_b * L + off_l
        old_logp = tl.load(OLD_LOGP).to(tl.float32)
    coef_1 = tl.exp(logp - old_logp)
    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
    advantage = tl.load(ADVANTAGES).to(tl.float32)
    per_token_loss1 = coef_1 * advantage
    per_token_loss2 = coef_2 * advantage
    # The unclipped term carries gradient only where it is the selected minimum.
    mask = per_token_loss2 >= per_token_loss1

    # d(-coef_1 * advantage) / d(logp) = -coef_1 * advantage.
    dlogp = -per_token_loss1 * mask
    if BETA != 0.0:
        REF_LOGP += off_b * L + off_l
        ref_logp = tl.load(REF_LOGP).to(tl.float32)
        # Derivative of exp(d) - d - 1 with d = ref_logp - logp, taken w.r.t. logp.
        dlogp += BETA * (1 - tl.exp(ref_logp - logp))

    # Chain in the upstream gradient and the temperature scaling of the logits.
    dlogp = dlogp * dloss / TEMPERATURE
    # NOTE(review): presumably guards the LOGITS reads above against the
    # DLOGITS writes below when both point at the same memory - confirm.
    tl.debug_barrier()
    for start_n in tl.range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
        probs = tl.exp(logits - lse)
        # d(logp) / d(logit_j) = 1[j == idx] - softmax_j.
        dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
        tl.store(DLOGITS + cols, dlogits, mask=cols < N)
223
+
224
+
225
class GrpoLossFunction(torch.autograd.Function):
    """Autograd wrapper around the fused Triton GRPO loss kernels.

    Only ``logits`` receives a gradient. ``kl`` and ``is_clipped`` are
    auxiliary outputs and are marked non-differentiable, because ``backward``
    only consumes the gradient of ``loss``.
    """

    @staticmethod
    def forward(
        ctx,
        logits,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low,
        eps_high,
        inplace,
    ):
        """Compute the per-token GRPO loss.

        Args:
            logits: Contiguous ``(B, L + 1, N)`` tensor; the last position is
                not scored.
            old_logp: Optional contiguous per-token log-probs of the old policy.
            ref_logp: Contiguous per-token log-probs of the reference policy;
                required when ``beta != 0``.
            completion_ids: Contiguous token ids, indexed as ``(B, L)``.
            advantages: One advantage per batch entry.
            completion_mask: Optional contiguous mask; tokens with value 0 are
                skipped and keep zero loss.
            temperature: Divisor applied to the logits.
            beta: Weight of the KL term (0 disables it).
            eps_low, eps_high: Clipping range is ``[1 - eps_low, 1 + eps_high]``.
            inplace: If true, ``backward`` writes the gradient into the storage
                of ``logits`` instead of allocating a new tensor.

        Returns:
            ``(loss, kl, is_clipped)``, each float32 of shape ``(B, L)``;
            ``kl`` is ``None`` when ``beta == 0``.
        """
        assert logits.is_contiguous() and completion_ids.is_contiguous()
        assert old_logp is None or old_logp.is_contiguous()
        assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True

        B, L_ADD_1, N = logits.shape
        L = L_ADD_1 - 1

        if completion_mask is not None:
            assert completion_mask.is_contiguous()

        # Both kernels read ADVANTAGES with unit stride (one value per batch
        # entry), so make sure a strided view is packed first.
        advantages = advantages.contiguous()

        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
        lse = torch.zeros_like(loss)
        is_clipped = torch.zeros_like(loss)
        kl = torch.zeros_like(loss) if beta != 0.0 else None
        kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
        _grpo_loss_fwd_kernel[(B, L)](
            logits,
            old_logp,
            ref_logp,
            completion_ids,
            completion_mask,
            advantages,
            loss,
            lse,
            kl,
            is_clipped,
            temperature,
            beta,
            eps_low,
            eps_high,
            L,
            N,
            **kwargs,
        )
        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)

        # backward() ignores gradients flowing into kl / is_clipped, so tell
        # autograd these outputs carry no gradient.
        aux_outputs = [out for out in (kl, is_clipped) if out is not None]
        ctx.mark_non_differentiable(*aux_outputs)
        return loss, kl, is_clipped

    @staticmethod
    def backward(ctx, *args):
        """Return the gradient for ``logits`` and ``None`` for every other input."""
        # Only the gradient of the first output (loss) is used.
        dloss = args[0]
        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
        temperature, beta, eps_low, eps_high, inplace = ctx.infos
        B, L_ADD_1, N = logits.shape
        L = L_ADD_1 - 1
        # In-place mode reuses the logits storage for the gradient.
        dlogits = logits.data if inplace else torch.empty_like(logits)
        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
        _grpo_loss_bwd_kernel[(B, L)](
            dloss,
            dlogits,
            logits,
            old_logp,
            ref_logp,
            completion_ids,
            advantages,
            completion_mask,
            lse,
            temperature,
            beta,
            eps_low,
            eps_high,
            *dloss.stride(),
            L,
            N,
            **kwargs,
        )
        # The grid only covers the first L positions; the extra last position
        # gets a zero gradient.
        dlogits[:, -1, :] = 0
        return dlogits, None, None, None, None, None, None, None, None, None, None
@@ -0,0 +1,207 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ from torch.nn.modules.utils import _pair
7
+
8
+ from liger_kernel.ops.softmax import _softmax_forward
9
+ from liger_kernel.ops.sparsemax import _sparsemax_backward
10
+ from liger_kernel.ops.sparsemax import _sparsemax_forward
11
+ from liger_kernel.ops.utils import calculate_settings
12
+ from liger_kernel.ops.utils import ensure_contiguous
13
+
14
+
15
@triton.jit
def _mask_fwd_kernel(
    scores_ptr,
    out_ptr,
    stride_b,
    stride_m,
    stride_n,
    L,
    mask_val: tl.constexpr,
    BLOCK: tl.constexpr,
    num_warps: tl.constexpr,
):
    """Copy one BLOCK x BLOCK tile, replacing entries above the diagonal.

    Program ids 0 / 1 / 2 select the row block, column block and batch
    entry. Entries with ``col > row`` are written as ``mask_val``; all other
    in-bounds entries are copied unchanged. The same strides are used for
    the input and the output.
    """
    row_block = tl.program_id(0)
    col_block = tl.program_id(1)
    batch_id = tl.program_id(2)

    row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
    col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
    # Lanes outside the L x L matrix are neither loaded nor stored.
    in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)

    base = scores_ptr + batch_id * stride_b
    offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
    future = col_idx[None, :] > row_idx[:, None]
    # Future positions are excluded from the load, so they take ``other``.
    mask_load = in_bounds & ~future
    out = tl.load(base + offs, mask=mask_load, other=mask_val, cache_modifier=".ca")
    tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".cs")
41
+
42
+
43
@triton.jit
def _mask_bwd_kernel(
    grad_in_ptr, out_ptr, stride_b, stride_m, stride_n, L, BLOCK: tl.constexpr, num_warps: tl.constexpr
):
    """Copy one BLOCK x BLOCK gradient tile, zeroing entries above the diagonal.

    Program ids 0 / 1 / 2 select the row block, column block and batch
    entry. Entries with ``col > row`` are written as zero; all other
    in-bounds entries pass through unchanged. The same strides are used for
    the input and the output.
    """
    row_block = tl.program_id(0)
    col_block = tl.program_id(1)
    batch_id = tl.program_id(2)

    row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
    col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
    # Lanes outside the L x L matrix are neither loaded nor stored.
    in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)

    base = grad_in_ptr + batch_id * stride_b
    offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
    grad_vals = tl.load(base + offs, mask=in_bounds, other=0.0, cache_modifier=".ca")

    future = col_idx[None, :] > row_idx[:, None]
    zero = tl.zeros(grad_vals.shape, dtype=grad_vals.dtype)
    out = tl.where(future, zero, grad_vals)

    tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".wb")
64
+
65
+
66
def _mask_inf_forward(scores: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``scores`` with every entry above the diagonal set to -1e9.

    Args:
        scores: Tensor of shape ``(*batch, L, L)`` that can be viewed as
            ``(N, L, L)``, where ``N`` is the product of the leading dims.

    Returns:
        A new tensor with the same shape as ``scores``.
    """
    import math

    *batch, L, _ = scores.shape
    # Product of plain Python ints; an empty ``batch`` yields 1.
    N = math.prod(batch)
    scores_f = scores.view(N, L, L)
    out = torch.empty_like(scores_f)

    sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
    BLOCK_SIZE, num_warps = calculate_settings(L)
    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
    _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=-1e9, BLOCK=BLOCK_SIZE, num_warps=num_warps)
    return out.view(*batch, L, L)
77
+
78
+
79
def _mask_inf_backward(grad: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``grad`` with every entry above the diagonal set to zero.

    Args:
        grad: Tensor of shape ``(*batch, L, L)`` that can be viewed as
            ``(N, L, L)``, where ``N`` is the product of the leading dims.

    Returns:
        A new tensor with the same shape as ``grad``.
    """
    import math

    *batch, L, _ = grad.shape
    # Product of plain Python ints; an empty ``batch`` yields 1.
    N = math.prod(batch)
    grad_f = grad.view(N, L, L)
    out = torch.empty_like(grad_f)

    sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
    BLOCK_SIZE, num_warps = calculate_settings(L)
    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
    _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
    return out.view(*batch, L, L)
90
+
91
+
92
def _mask_zero_forward(scores: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``scores`` with every entry above the diagonal set to 0.0.

    Args:
        scores: Tensor of shape ``(*batch, L, L)`` that can be viewed as
            ``(N, L, L)``, where ``N`` is the product of the leading dims.

    Returns:
        A new tensor with the same shape as ``scores``.
    """
    import math

    *batch, L, _ = scores.shape
    # Product of plain Python ints; an empty ``batch`` yields 1.
    N = math.prod(batch)
    scores_f = scores.view(N, L, L)
    out = torch.empty_like(scores_f)

    sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
    BLOCK_SIZE, num_warps = calculate_settings(L)
    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
    _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=0.0, BLOCK=BLOCK_SIZE, num_warps=num_warps)
    return out.view(*batch, L, L)
103
+
104
+
105
def _mask_zero_backward(grad: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``grad`` with every entry above the diagonal set to zero.

    Args:
        grad: Tensor of shape ``(*batch, L, L)`` that can be viewed as
            ``(N, L, L)``, where ``N`` is the product of the leading dims.

    Returns:
        A new tensor with the same shape as ``grad``.
    """
    import math

    *batch, L, _ = grad.shape
    # Product of plain Python ints; an empty ``batch`` yields 1.
    N = math.prod(batch)
    grad_f = grad.view(N, L, L)
    out = torch.empty_like(grad_f)

    sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
    BLOCK_SIZE, num_warps = calculate_settings(L)
    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
    _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
    return out.view(*batch, L, L)
116
+
117
+
118
class LigerMultiTokenAttentionFunction(torch.autograd.Function):
    """Masked scores -> softmax or sparsemax -> conv2d -> masked output.

    The forward pass masks entries above the diagonal of ``scores``,
    normalises along the last dimension, convolves the result with
    ``weight`` / ``bias`` and zeroes entries above the diagonal of the
    convolution output. The backward pass reverses those steps by hand.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, scores, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, sparse=False):
        """Run the fused forward pass.

        Args:
            scores: Attention scores whose last two dims are ``(L, L)``.
            weight: Convolution weight passed to ``F.conv2d``.
            bias: Optional convolution bias.
            stride, padding, dilation, groups: Forwarded to ``F.conv2d``.
            sparse: Use sparsemax instead of softmax (fp32 scores only).

        Returns:
            The convolved activations with entries above the diagonal zeroed.

        Raises:
            RuntimeError: If ``sparse`` is set and the scores are not float32.
        """
        scores_inf = _mask_inf_forward(scores)

        ctx.sparse = sparse

        # The masked scores themselves are not needed by backward, so only
        # the activation output (plus the sparsemax side output) is saved.
        if sparse:
            if scores_inf.dtype != torch.float32:
                raise RuntimeError("Liger sparse multi-token attention currently only supports fp32 input scores")
            activation_output, out_flat_sparse = _sparsemax_forward(scores_inf, dim=-1)
            ctx.save_for_backward(activation_output, out_flat_sparse, weight, bias)
            ctx.out_flat_sparse_saved = True
        else:
            activation_output, _, _, _ = _softmax_forward(scores_inf)
            ctx.save_for_backward(activation_output, weight, bias)
            ctx.out_flat_sparse_saved = False

        out_conv = F.conv2d(
            activation_output,
            weight,
            bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )

        out = _mask_zero_forward(out_conv)

        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.dim = -1

        return out

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out):
        """Return gradients for ``scores``, ``weight`` and ``bias``.

        The remaining five forward arguments are non-tensor settings and get
        ``None``.

        Raises:
            RuntimeError: If the sparse flag is set but the sparsemax side
                output was not saved by ``forward``.
        """
        if ctx.out_flat_sparse_saved:
            activation_output, out_flat_sparse, weight, bias = ctx.saved_tensors
        else:
            activation_output, weight, bias = ctx.saved_tensors
            out_flat_sparse = None

        use_sparsemax = ctx.sparse
        dim = ctx.dim
        stride, padding, dilation, groups = (ctx.stride, ctx.padding, ctx.dilation, ctx.groups)

        # Undo the output mask: masked positions receive no gradient.
        grad_conv = _mask_zero_backward(grad_out)

        # Gradient w.r.t. the conv input. Passing the exact input shape keeps
        # the result well-defined when stride > 1, where a transposed
        # convolution alone cannot recover the original spatial size.
        grad_probs = torch.nn.grad.conv2d_input(
            activation_output.shape,
            weight,
            grad_conv,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )

        grad_weight = torch.nn.grad.conv2d_weight(
            input=activation_output,
            weight_size=weight.shape,
            grad_output=grad_conv,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        grad_bias = None
        if bias is not None:
            # Sum over every dim except the channel dim.
            grad_bias = grad_conv.sum(dim=(0, 2, 3))

        if use_sparsemax:
            if not ctx.out_flat_sparse_saved or out_flat_sparse is None:
                raise RuntimeError("Internal error: Sparse flag is set but sparse tensor was not saved.")
            grad_scores_inf = _sparsemax_backward(grad_probs, out_flat_sparse, dim=dim)
        else:
            # Softmax backward: p * (g - sum(g * p)) along the last dim.
            dot = (grad_probs * activation_output).sum(dim=-1, keepdim=True)
            grad_scores_inf = activation_output * (grad_probs - dot)

        # Undo the input mask: masked scores receive no gradient.
        grad_scores = _mask_inf_backward(grad_scores_inf)

        return (grad_scores, grad_weight, grad_bias, None, None, None, None, None)