liger-kernel 0.5.10__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (48)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/ops/dyt.py +0 -2
  5. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  6. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  7. liger_kernel/ops/geglu.py +1 -1
  8. liger_kernel/ops/layer_norm.py +126 -89
  9. liger_kernel/ops/multi_token_attention.py +207 -0
  10. liger_kernel/ops/rms_norm.py +267 -56
  11. liger_kernel/ops/rope.py +1 -1
  12. liger_kernel/ops/softmax.py +201 -0
  13. liger_kernel/ops/sparsemax.py +62 -50
  14. liger_kernel/ops/swiglu.py +1 -1
  15. liger_kernel/transformers/__init__.py +8 -0
  16. liger_kernel/transformers/functional.py +67 -0
  17. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  18. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  19. liger_kernel/transformers/model/gemma.py +25 -8
  20. liger_kernel/transformers/model/gemma2.py +27 -8
  21. liger_kernel/transformers/model/gemma3.py +63 -99
  22. liger_kernel/transformers/model/glm4.py +16 -7
  23. liger_kernel/transformers/model/llama.py +25 -7
  24. liger_kernel/transformers/model/llama4.py +108 -0
  25. liger_kernel/transformers/model/llava.py +95 -124
  26. liger_kernel/transformers/model/mistral.py +13 -8
  27. liger_kernel/transformers/model/mixtral.py +16 -7
  28. liger_kernel/transformers/model/mllama.py +16 -7
  29. liger_kernel/transformers/model/olmo2.py +16 -7
  30. liger_kernel/transformers/model/paligemma.py +8 -1
  31. liger_kernel/transformers/model/phi3.py +25 -8
  32. liger_kernel/transformers/model/qwen2.py +24 -7
  33. liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
  34. liger_kernel/transformers/model/qwen2_vl.py +38 -100
  35. liger_kernel/transformers/model/qwen3.py +11 -3
  36. liger_kernel/transformers/model/qwen3_moe.py +10 -6
  37. liger_kernel/transformers/model/smollm3.py +189 -0
  38. liger_kernel/transformers/monkey_patch.py +389 -82
  39. liger_kernel/transformers/multi_token_attention.py +64 -0
  40. liger_kernel/transformers/rms_norm.py +40 -4
  41. liger_kernel/transformers/softmax.py +12 -0
  42. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/METADATA +18 -14
  43. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/RECORD +47 -37
  44. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/WHEEL +1 -1
  45. liger_kernel/transformers/gema3_rms.py +0 -8
  46. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/LICENSE +0 -0
  47. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/NOTICE +0 -0
  48. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py CHANGED
@@ -40,7 +40,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
  tanh_result = tanh(tanh_arg)
  geglu_a = 0.5 * a_row * (1 + tanh_result)
- c_row = geglu_a * b_row
+ c_row = geglu_a.cast(b_row.dtype) * b_row
  tl.store(c + col_offsets, c_row, mask=mask)

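The fix above casts the tanh-GELU gate back to `b_row`'s dtype before the elementwise product, so the stored result keeps the input dtype even though the gate is evidently computed at higher precision. Below is a minimal pure-PyTorch sketch of the same tanh-GEGLU math for spot-checking the kernel; the function name is illustrative and not part of liger-kernel.

import math

import torch
import torch.nn.functional as F


def geglu_tanh_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Gate computed in fp32 (the Triton kernel also upcasts before tanh) ...
    a_f32 = a.float()
    inner = math.sqrt(2.0 / math.pi) * (a_f32 + 0.044715 * a_f32.pow(3))
    gelu_a = 0.5 * a_f32 * (1.0 + torch.tanh(inner))
    # ... then cast back to b's dtype before the product, mirroring the new
    # geglu_a.cast(b_row.dtype) in the kernel.
    return gelu_a.to(b.dtype) * b


a = torch.randn(4, 128, dtype=torch.bfloat16)
b = torch.randn(4, 128, dtype=torch.bfloat16)
out = geglu_tanh_reference(a, b)
ref = F.gelu(a.float(), approximate="tanh").to(b.dtype) * b
assert out.dtype == torch.bfloat16
assert torch.allclose(out.float(), ref.float(), atol=1e-2)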
liger_kernel/ops/layer_norm.py CHANGED
@@ -1,4 +1,3 @@
- import math
  import operator

  import torch
@@ -43,30 +42,45 @@ def _layer_norm_forward_kernel(
  https://arxiv.org/abs/1607.06450
  https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
  """
- row_idx = tl.program_id(0)
+ row_idx = tl.program_id(0).to(tl.int64)
  col_offsets = tl.arange(0, BLOCK_SIZE)
  mask = col_offsets < n_cols

- Y_ptr += row_idx * Y_row_stride
- X_ptr += row_idx * X_row_stride
- Mean_ptr += row_idx * Mean_row_stride
- RSTD_ptr += row_idx * RSTD_row_stride
-
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
- W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
- B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
-
- mean = tl.sum(X_row, axis=0) / n_cols
- Xmm = tl.where(mask, X_row - mean, 0)
- var = tl.sum(Xmm * Xmm, axis=0) / n_cols
+ # Pre-load weights and bias in fp32 to avoid repeated conversions
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+ B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
+ W_f32 = W_row.to(tl.float32)
+ B_f32 = B_row.to(tl.float32)
+
+ # Calculate pointers for this row
+ row_X_ptr = X_ptr + row_idx * X_row_stride
+ row_Y_ptr = Y_ptr + row_idx * Y_row_stride
+ row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
+ row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+ # Load input data and convert to fp32 for numerical stability
+ X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
+ X_f32 = X_row.to(tl.float32)
+
+ # Compute statistics in fp32 for numerical stability
+ n_cols_f32 = n_cols.to(tl.float32)
+ mean = tl.sum(X_f32, axis=0) / n_cols_f32
+ X_centered = X_f32 - mean
+ # Apply mask to variance calculation to exclude contributions from masked elements
+ X_centered_masked = tl.where(mask, X_centered, 0.0)
+ var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols_f32
  rstd = rsqrt(var + eps)

- tl.store(Mean_ptr, mean)
- tl.store(RSTD_ptr, rstd)
+ # Store statistics (convert back to original dtype only once)
+ tl.store(row_Mean_ptr, mean.to(X_row.dtype))
+ tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))

- Y_row = Xmm * rstd * W_row + B_row
+ # Fused normalization and affine transformation
+ # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
+ Y_f32 = X_centered * rstd * W_f32 + B_f32

- tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+ # Store output (single conversion back to original dtype)
+ tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)


  @triton.jit
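The rewritten forward kernel assigns one Triton program per row, computes mean and variance in fp32, and casts back to the input dtype only when storing. The following is a small PyTorch reference of that math, offered as an illustration rather than the library's code path; the function name is made up here.

import torch
import torch.nn.functional as F


def layer_norm_forward_reference(X, W, B, eps=1e-5):
    # Statistics in fp32, matching the kernel's X_f32 / n_cols_f32 path.
    X_f32 = X.float()
    mean = X_f32.mean(dim=-1, keepdim=True)
    X_centered = X_f32 - mean
    var = (X_centered * X_centered).mean(dim=-1, keepdim=True)
    rstd = torch.rsqrt(var + eps)
    # Fused normalization + affine, cast back to the input dtype only once.
    Y = (X_centered * rstd * W.float() + B.float()).to(X.dtype)
    return Y, mean.squeeze(-1).to(X.dtype), rstd.squeeze(-1).to(X.dtype)


X = torch.randn(8, 1024, dtype=torch.bfloat16)
W = torch.randn(1024, dtype=torch.bfloat16)
B = torch.randn(1024, dtype=torch.bfloat16)
Y, mean, rstd = layer_norm_forward_reference(X, W, B)
ref = F.layer_norm(X.float(), (1024,), W.float(), B.float(), eps=1e-5)
assert torch.allclose(Y.float(), ref, atol=5e-2, rtol=5e-2)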
@@ -81,73 +95,87 @@ def _layer_norm_backward_kernel(
  DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
  stride_x, # stride of each row in input
  stride_dx, # stride of each row in input grad
- stride_dw, # stride of each row in weights grad
- stride_db, # stride of each row in bias grad
  stride_dy, # stride of each row in output grad
- n_rows,
  n_cols,
- rows_per_program: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
  dtype: tl.constexpr,
+ atomic_dtype: tl.constexpr,
  ):
  """
  References:
  https://arxiv.org/abs/1607.06450
  https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
- https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
  """
- row_block_id = tl.program_id(0)
- row_start = row_block_id * rows_per_program
- row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+ row_idx = tl.program_id(0).to(tl.int64)
  cols = tl.arange(0, BLOCK_SIZE)
  mask = cols < n_cols

- dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
- db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-
- X_ptr += row_start * stride_x
- Mean_ptr += row_start
- RSTD_ptr += row_start
- DX_ptr += row_start * stride_dx
- DY_ptr += row_start * stride_dy
-
- for _ in range(row_start, row_end):
- x = tl.load(X_ptr + cols, mask=mask, other=0.0)
- w = tl.load(W_ptr + cols, mask=mask, other=0.0)
- dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
- mean = tl.load(Mean_ptr)
- rstd = tl.load(RSTD_ptr)
-
- x_hat = (x - mean) * rstd
- wdy = w * dy
- c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
- c2 = tl.sum(wdy, axis=0) / n_cols
- dx = (wdy - (x_hat * c1 + c2)) * rstd
- tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
-
- dw_row += dy * x_hat
- db_row += dy
-
- X_ptr += stride_x
- Mean_ptr += 1
- RSTD_ptr += 1
- DX_ptr += stride_dx
- DY_ptr += stride_dy
-
- tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
- tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
+ # Pre-load weights once (same optimization as forward pass)
+ w = tl.load(W_ptr + cols, mask=mask, other=0.0)
+ w_f32 = w.to(tl.float32)
+ n_cols_f32 = n_cols.to(tl.float32)
+
+ # Calculate pointers for this specific row
+ row_X_ptr = X_ptr + row_idx * stride_x
+ row_DX_ptr = DX_ptr + row_idx * stride_dx
+ row_DY_ptr = DY_ptr + row_idx * stride_dy
+ row_Mean_ptr = Mean_ptr + row_idx
+ row_RSTD_ptr = RSTD_ptr + row_idx
+
+ # Load data for this row
+ x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+ dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+ mean = tl.load(row_Mean_ptr)
+ rstd = tl.load(row_RSTD_ptr)
+
+ # Convert to fp32 for numerical stability
+ x_f32 = x.to(tl.float32)
+ dy_f32 = dy.to(tl.float32)
+ mean_f32 = mean.to(tl.float32)
+ rstd_f32 = rstd.to(tl.float32)
+
+ # Compute backward pass for this row
+ x_hat = (x_f32 - mean_f32) * rstd_f32
+ wdy = w_f32 * dy_f32
+ c1 = tl.sum(x_hat * wdy, axis=0) / n_cols_f32
+ c2 = tl.sum(wdy, axis=0) / n_cols_f32
+ dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+ # Store input gradient
+ tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
+
+ # Accumulate weight and bias gradients using atomic operations
+ dw = dy_f32 * x_hat
+ db = dy_f32
+ tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
+ tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)


  def layer_norm_forward(X, W, B, eps):
+ """
+ Args:
+ X: Input tensor of shape (..., hidden_size)
+ W: Weight tensor of shape (hidden_size,)
+ B: Bias tensor of shape (hidden_size,)
+ eps: Small constant for numerical stability
+
+ Returns:
+ Tuple of (output, input, mean, rstd, block_size, num_warps)
+ """
  shape = X.shape
  dim = shape[-1]
  X = X.view(-1, dim)
  n_rows, n_cols = X.shape
+
+ # Calculate optimal block size and warp configuration
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+ # Allocate output tensors
  Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
  Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
  RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+ # Validate input dimensions
  if X.shape[1] != W.shape[0]:
  raise ValueError(
  f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
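The backward kernel now also processes exactly one row per program and accumulates DW/DB with tl.atomic_add instead of per-SM partial buffers. The per-row math it implements is the standard LayerNorm backward; below is a hedged PyTorch sketch (names are illustrative, not liger-kernel APIs) that reproduces dx, dW, and dB and checks them against autograd.

import torch
import torch.nn.functional as F


def layer_norm_backward_reference(dY, X, W, mean, rstd):
    # Per-row math from _layer_norm_backward_kernel, vectorized over rows.
    x_hat = (X.float() - mean.float().unsqueeze(-1)) * rstd.float().unsqueeze(-1)
    wdy = W.float() * dY.float()
    n_cols = X.shape[-1]
    c1 = (x_hat * wdy).sum(dim=-1, keepdim=True) / n_cols
    c2 = wdy.sum(dim=-1, keepdim=True) / n_cols
    dX = (wdy - (x_hat * c1 + c2)) * rstd.float().unsqueeze(-1)
    # The kernel adds each row's dw/db into DW/DB via tl.atomic_add; a plain
    # sum over rows gives the same totals.
    dW = (dY.float() * x_hat).sum(dim=0)
    dB = dY.float().sum(dim=0)
    return dX.to(X.dtype), dW, dB


X = torch.randn(16, 256, requires_grad=True)
W = torch.randn(256, requires_grad=True)
B = torch.randn(256, requires_grad=True)
Y = F.layer_norm(X, (256,), W, B, eps=1e-6)
dY = torch.randn_like(Y)
Y.backward(dY)

mean = X.detach().mean(dim=-1)
rstd = torch.rsqrt(X.detach().var(dim=-1, unbiased=False) + 1e-6)
dX, dW, dB = layer_norm_backward_reference(dY, X.detach(), W.detach(), mean, rstd)
assert torch.allclose(dX, X.grad, atol=1e-3)
assert torch.allclose(dW, W.grad, atol=1e-3)
assert torch.allclose(dB, B.grad, atol=1e-3)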
@@ -159,7 +187,9 @@ def layer_norm_forward(X, W, B, eps):
  if X.device.type == "xpu":
  kernel_args["grf_mode"] = "large"

- _layer_norm_forward_kernel[(n_rows,)](
+ # Launch kernel with one thread block per row for optimal performance
+ grid = (n_rows,)
+ _layer_norm_forward_kernel[grid](
  Y,
  Y.stride(0),
  X,
@@ -176,35 +206,43 @@ def layer_norm_forward(X, W, B, eps):
  eps,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=num_warps,
- **kernel_args, # XPU-specific optimization
+ **kernel_args,
  )
+
  return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps


  def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+ """
+ Args:
+ dY: Gradient of output
+ X: Input tensor
+ W: Weight tensor
+ B: Bias tensor
+ Mean: Pre-computed mean
+ RSTD: Pre-computed reciprocal standard deviation
+
+ Returns:
+ Tuple of (input_grad, weight_grad, bias_grad)
+ """
  shape = dY.shape
  dim = shape[-1]
  dY = dY.view(-1, dim)
  n_rows, n_cols = dY.shape

- sm_count = 1
- if X.device.type == "cuda":
- sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
- elif X.device.type == "xpu":
- sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
-
+ # Allocate gradient tensors
  DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
- _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
- _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
+ # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
+ grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
+ DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+ DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)

+ # Calculate optimal block size and warp configuration
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
  if n_cols > BLOCK_SIZE:
- raise RuntimeError(
- f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
- )
+ raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")

- rows_per_program = math.ceil(n_rows / sm_count)
- grid = (sm_count,)
+ # Determine dtype for triton operations
  triton_dtype = (
  tl.float32
  if X.dtype == torch.float32
@@ -212,41 +250,40 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
  if X.dtype == torch.bfloat16
  else tl.float16
  if X.dtype == torch.float16
- else tl.float32 # fallback to float32 for other types
+ else tl.float32 # fallback
  )

+ # Use float32 for atomic operations if bfloat16 is not supported
+ atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+
+ kernel_args = {"num_warps": num_warps}
  # XPU-specific optimization
- kernel_args = {}
  if X.device.type == "xpu":
  kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})

+ # Launch kernel with one thread block per row for optimal performance
+ grid = (n_rows,)
  _layer_norm_backward_kernel[grid](
  X,
  W,
  Mean,
  RSTD,
  DX,
- _DW,
- _DB,
+ DW,
+ DB,
  dY,
  X.stride(0),
  DX.stride(0),
- _DW.stride(0),
- _DB.stride(0),
  dY.stride(0),
- n_rows,
  n_cols,
- rows_per_program,
  BLOCK_SIZE=BLOCK_SIZE,
  dtype=triton_dtype,
- **kernel_args, # XPU-specific optimization
+ atomic_dtype=atomic_dtype,
+ **kernel_args,
  )

- DW = _DW.sum(dim=0).to(W.dtype)
- DB = _DB.sum(dim=0).to(W.dtype)
-
  DX = DX.view(*shape)
- return DX, DW, DB
+ return DX, DW.to(W.dtype), DB.to(W.dtype)


  class LigerLayerNormFunction(torch.autograd.Function):
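Because tl.atomic_add has no bfloat16 support, the wrapper above accumulates DW/DB in float32 when the weights are bf16 and casts back to the weight dtype once at the end. Below is a CPU-only sketch of that allocation-and-accumulate pattern, with a Python loop standing in for the per-row atomics; it is an illustration, not the liger-kernel code path.

import torch

n_rows, n_cols = 64, 512
W = torch.randn(n_cols, dtype=torch.bfloat16)
dy = torch.randn(n_rows, n_cols, dtype=torch.bfloat16)
x_hat = torch.randn(n_rows, n_cols, dtype=torch.bfloat16)

# tl.atomic_add cannot target bfloat16, so accumulate in float32 instead.
grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
DW = torch.zeros(n_cols, dtype=grad_dtype)
DB = torch.zeros(n_cols, dtype=grad_dtype)

for row in range(n_rows):  # each Triton program handles one row
    DW += dy[row].float() * x_hat[row].float()
    DB += dy[row].float()

# Single cast back to the weight dtype at the end, as in layer_norm_backward.
DW, DB = DW.to(W.dtype), DB.to(W.dtype)
ref_DW = (dy.float() * x_hat.float()).sum(dim=0)
assert torch.allclose(DW.float(), ref_DW, rtol=1e-2, atol=1e-2)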
liger_kernel/ops/multi_token_attention.py ADDED
@@ -0,0 +1,207 @@
+ import torch
+ import torch.nn.functional as F
+ import triton
+ import triton.language as tl
+
+ from torch.nn.modules.utils import _pair
+
+ from liger_kernel.ops.softmax import _softmax_forward
+ from liger_kernel.ops.sparsemax import _sparsemax_backward
+ from liger_kernel.ops.sparsemax import _sparsemax_forward
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ @triton.jit
+ def _mask_fwd_kernel(
+ scores_ptr,
+ out_ptr,
+ stride_b,
+ stride_m,
+ stride_n,
+ L,
+ mask_val: tl.constexpr,
+ BLOCK: tl.constexpr,
+ num_warps: tl.constexpr,
+ ):
+ row_block = tl.program_id(0)
+ col_block = tl.program_id(1)
+ batch_id = tl.program_id(2)
+
+ row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
+ col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
+ in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
+
+ base = scores_ptr + batch_id * stride_b
+ offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
+ future = col_idx[None, :] > row_idx[:, None]
+ mask_load = in_bounds & ~future
+ out = tl.load(base + offs, mask=mask_load, other=mask_val, cache_modifier=".ca")
+ tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".cs")
+
+
+ @triton.jit
+ def _mask_bwd_kernel(
+ grad_in_ptr, out_ptr, stride_b, stride_m, stride_n, L, BLOCK: tl.constexpr, num_warps: tl.constexpr
+ ):
+ row_block = tl.program_id(0)
+ col_block = tl.program_id(1)
+ batch_id = tl.program_id(2)
+
+ row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
+ col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
+ in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
+
+ base = grad_in_ptr + batch_id * stride_b
+ offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
+ grad_vals = tl.load(base + offs, mask=in_bounds, other=0.0, cache_modifier=".ca")
+
+ future = col_idx[None, :] > row_idx[:, None]
+ zero = tl.zeros(grad_vals.shape, dtype=grad_vals.dtype)
+ out = tl.where(future, zero, grad_vals)
+
+ tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".wb")
+
+
+ def _mask_inf_forward(scores: torch.Tensor) -> torch.Tensor:
+ *batch, L, _ = scores.shape
+ N = int(torch.prod(torch.tensor(batch))) if batch else 1
+ scores_f = scores.view(N, L, L)
+ out = torch.empty_like(scores_f)
+
+ sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
+ BLOCK_SIZE, num_warps = calculate_settings(L)
+ grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+ _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=-1e9, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+ return out.view(*batch, L, L)
+
+
+ def _mask_inf_backward(grad: torch.Tensor) -> torch.Tensor:
+ *batch, L, _ = grad.shape
+ N = int(torch.prod(torch.tensor(batch))) if batch else 1
+ grad_f = grad.view(N, L, L)
+ out = torch.empty_like(grad_f)
+
+ sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
+ BLOCK_SIZE, num_warps = calculate_settings(L)
+ grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+ _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+ return out.view(*batch, L, L)
+
+
+ def _mask_zero_forward(scores: torch.Tensor) -> torch.Tensor:
+ *batch, L, _ = scores.shape
+ N = int(torch.prod(torch.tensor(batch))) if batch else 1
+ scores_f = scores.view(N, L, L)
+ out = torch.empty_like(scores_f)
+
+ sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
+ BLOCK_SIZE, num_warps = calculate_settings(L)
+ grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+ _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=0.0, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+ return out.view(*batch, L, L)
+
+
+ def _mask_zero_backward(grad: torch.Tensor) -> torch.Tensor:
+ *batch, L, _ = grad.shape
+ N = int(torch.prod(torch.tensor(batch))) if batch else 1
+ grad_f = grad.view(N, L, L)
+ out = torch.empty_like(grad_f)
+
+ sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
+ BLOCK_SIZE, num_warps = calculate_settings(L)
+ grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+ _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+ return out.view(*batch, L, L)
+
+
+ class LigerMultiTokenAttentionFunction(torch.autograd.Function):
+ @staticmethod
+ @ensure_contiguous
+ def forward(ctx, scores, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, sparse=False):
+ scores_inf = _mask_inf_forward(scores)
+
+ out_flat_sparse = None
+ activation_output = None
+
+ ctx.sparse = sparse
+
+ if sparse:
+ if scores_inf.dtype != torch.float32:
+ raise RuntimeError("Liger sparse multi-token attention currently only supports fp32 input scores")
+ probs_sparse, out_flat_sparse = _sparsemax_forward(scores_inf, dim=-1)
+ activation_output = probs_sparse
+ ctx.save_for_backward(scores_inf, activation_output, out_flat_sparse, weight, bias)
+ ctx.out_flat_sparse_saved = True
+ else:
+ probs_softmax, _, _, _ = _softmax_forward(scores_inf)
+ activation_output = probs_softmax
+ ctx.save_for_backward(scores_inf, activation_output, weight, bias)
+ ctx.out_flat_sparse_saved = False
+
+ out_conv = F.conv2d(
+ activation_output,
+ weight,
+ bias,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+
+ out = _mask_zero_forward(out_conv)
+
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.dim = -1
+
+ return out
+
+ @staticmethod
+ @ensure_contiguous
+ def backward(ctx, grad_out):
+ if ctx.out_flat_sparse_saved:
+ scores_inf, activation_output, out_flat_sparse, weight, bias = ctx.saved_tensors
+ else:
+ scores_inf, activation_output, weight, bias = ctx.saved_tensors
+ out_flat_sparse = None
+
+ use_sparsemax = ctx.sparse
+ dim = ctx.dim
+ stride, padding, dilation, groups = (ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
+
+ grad_conv = _mask_zero_backward(grad_out)
+
+ grad_probs = F.conv_transpose2d(
+ grad_conv, weight, None, stride=stride, padding=padding, dilation=dilation, groups=groups
+ )
+
+ grad_weight = torch.nn.grad.conv2d_weight(
+ input=activation_output,
+ weight_size=weight.shape,
+ grad_output=grad_conv,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+ grad_bias = None
+ if bias is not None:
+ grad_bias = grad_conv.sum(dim=(0, 2, 3))
+
+ grad_scores_inf = None
+ if use_sparsemax:
+ if not ctx.out_flat_sparse_saved or out_flat_sparse is None:
+ raise RuntimeError("Internal error: Sparse flag is set but sparse tensor was not saved.")
+ grad_scores_inf = _sparsemax_backward(grad_probs, out_flat_sparse, dim=dim)
+ else:
+ grad_probs_cont = grad_probs
+ probs_cont = activation_output
+ dot = (grad_probs_cont * probs_cont).sum(dim=-1, keepdim=True)
+ grad_scores_inf = probs_cont * (grad_probs_cont - dot)
+
+ grad_scores = _mask_inf_backward(grad_scores_inf)
+
+ return (grad_scores, grad_weight, grad_bias, None, None, None, None, None)
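The new LigerMultiTokenAttentionFunction masks future positions, applies softmax (or sparsemax when sparse=True), convolves the attention maps with F.conv2d, re-masks the result, and wires the matching backward by hand. Below is a hedged usage sketch that calls the autograd Function directly, assuming a CUDA device (the masking kernels are Triton) and padding that preserves the score-map shape; shapes and hyperparameters are illustrative, and the higher-level module wrapper lives in liger_kernel/transformers/multi_token_attention.py.

import torch

from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction

batch, heads, seq_len, kernel_size = 2, 8, 128, 3
scores = torch.randn(batch, heads, seq_len, seq_len, device="cuda", requires_grad=True)
# Conv2d weight that mixes attention maps across heads: (out_ch, in_ch / groups, k, k).
weight = torch.randn(heads, heads, kernel_size, kernel_size, device="cuda", requires_grad=True)
bias = torch.zeros(heads, device="cuda", requires_grad=True)

out = LigerMultiTokenAttentionFunction.apply(
    scores,
    weight,
    bias,
    1,                 # stride
    kernel_size // 2,  # padding that keeps the (L, L) score map size
    1,                 # dilation
    1,                 # groups
    False,             # sparse=False selects the softmax path
)
out.sum().backward()   # exercises the masked conv / softmax backward
print(out.shape, scores.grad.shape)  # both torch.Size([2, 8, 128, 128])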